Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 1 | # Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. |
| 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 16 | # Description: |
| 17 | # Defines the basic numeric type classes for tensors. |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 18 | import enum |
| 19 | |
Diego Russo | ea6111a | 2020-04-14 18:41:58 +0100 | [diff] [blame] | 20 | from .numeric_util import round_up_divide |
| 21 | |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 22 | |
class BaseType(enum.Flag):
    """Bit flags identifying the fundamental kind of a data type.

    Signed, Unsigned, Asymmetric and Int each occupy one bit; the four
    integer kinds below them are combinations of those bits. Every other
    kind uses a single bit of its own.
    """

    # Single-bit qualifiers that the integer kinds are composed from.
    Signed = 1 << 0
    Unsigned = 1 << 1
    Asymmetric = 1 << 2
    Int = 1 << 3

    # Integer kinds, expressed as combinations of the qualifier bits.
    SignedInt = Signed | Int
    UnsignedInt = Unsigned | Int
    AsymmSInt = Signed | Asymmetric | Int
    AsymmUInt = Unsigned | Asymmetric | Int

    # Remaining kinds, one bit each.
    Float = 1 << 4
    BFloat = 1 << 5
    Bool = 1 << 6
    String = 1 << 7
    Resource = 1 << 8
    Variant = 1 << 9
    Complex = 1 << 10
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 39 | |
| 40 | |
class DataType:
    """Defines a data type. Consists of a base type, and the number of bits used for this type.

    Instances compare equal, and hash equal, when both their base type and
    their bit width match.
    """

    __slots__ = "type", "bits"

    def __init__(self, type_, bits):
        """Store the base type (looked up in ``stem_name`` by ``__str__``) and the bit width."""
        self.type = type_
        self.bits = bits

    def __eq__(self, other):
        """Return True when ``other`` is a DataType with the same type and bits.

        Comparing against an object that is not a DataType returns
        NotImplemented, so expressions such as ``dtype == None`` or
        ``dtype in some_list`` evaluate to False instead of raising
        AttributeError on the missing ``type``/``bits`` attributes.
        """
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.bits == other.bits

    def __hash__(self):
        """Hash on the same (type, bits) pair that ``__eq__`` compares."""
        return hash((self.type, self.bits))

    def size_in_bytes(self):
        """Return the bit width passed through ``round_up_divide(bits, 8)``.

        NOTE(review): round_up_divide is imported from numeric_util and is
        presumably a ceiling division, giving whole bytes - confirm there.
        """
        return round_up_divide(self.bits, 8)

    def size_in_bits(self):
        """Return the bit width this data type was created with."""
        return self.bits

    def __str__(self):
        """Return the display name, e.g. ``int8`` or ``bool``.

        Raises KeyError if ``self.type`` has no entry in ``stem_name``.
        """
        stem, needs_format = DataType.stem_name[self.type]
        if not needs_format:
            return stem
        # Stems flagged True contain a %s placeholder for the bit width.
        return stem % (self.bits,)

    __repr__ = __str__

    # Maps a base type to (name stem, whether the stem takes the bit width).
    stem_name = {
        BaseType.UnsignedInt: ("uint%s", True),
        BaseType.SignedInt: ("int%s", True),
        BaseType.AsymmUInt: ("quint%s", True),
        BaseType.AsymmSInt: ("qint%s", True),
        BaseType.Float: ("float%s", True),
        BaseType.BFloat: ("bfloat%s", True),
        BaseType.Bool: ("bool", False),
        BaseType.String: ("string", False),
        BaseType.Resource: ("resource", False),
        BaseType.Variant: ("variant", False),
        BaseType.Complex: ("complex%s", True),
    }
| 84 | |
| 85 | |
# Generate the standard set of data types and attach each one to DataType as
# a class attribute (DataType.int8, DataType.quint16, DataType.bool, ...).

# Width-suffixed types: the attribute name is the stem followed by the width.
for _stem, _base, _widths in (
    ("int", BaseType.SignedInt, (8, 16, 32, 64)),
    ("uint", BaseType.UnsignedInt, (8, 16, 32, 64)),
    ("quint", BaseType.AsymmUInt, (4, 8, 12, 16, 32)),
    ("qint", BaseType.AsymmSInt, (4, 8, 12, 16, 32)),
    ("float", BaseType.Float, (16, 32, 64)),
    ("complex", BaseType.Complex, (64, 128)),
):
    for _width in _widths:
        setattr(DataType, "%s%d" % (_stem, _width), DataType(_base, _width))
# Drop the loop temporaries so the module namespace gains no extra names.
del _stem, _base, _widths, _width

# Types whose attribute name carries no width suffix.
for _name, _base, _width in (
    ("string", BaseType.String, 64),
    ("bool", BaseType.Bool, 8),
    ("resource", BaseType.Resource, 8),
    ("variant", BaseType.Variant, 8),
):
    setattr(DataType, _name, DataType(_base, _width))
del _name, _base, _width