blob: 829cef38d5c86c76641f6b1a6357b0073a804b5a [file] [log] [blame]
# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Defines the basic numeric type classes for tensors.
import enum
from typing import Any
import numpy as np
from .numeric_util import round_up_divide
class BaseType(enum.Flag):
Signed = 1
Unsigned = 2
Asymmetric = 4
Int = 8
SignedInt = Int | Signed
UnsignedInt = Int | Unsigned
AsymmSInt = Int | Asymmetric | Signed
AsymmUInt = Int | Asymmetric | Unsigned
Float = 16
BFloat = 32
Bool = 64
String = 128
Resource = 256
Variant = 512
Complex = 1024
class DataType:
"""Defines a data type. Consists of a base type, and the number of bits used for this type"""
__slots__ = "type", "bits"
int4: Any
int8: Any
int16: Any
int32: Any
int48: Any
int64: Any
uint8: Any
uint16: Any
uint32: Any
uint64: Any
quint4: Any
quint8: Any
quint12: Any
quint16: Any
quint32: Any
qint4: Any
qint8: Any
qint12: Any
qint16: Any
qint32: Any
float16: Any
float32: Any
float64: Any
string: Any
bool: Any
resource: Any
variant: Any
complex64: Any
complex128: Any
def __init__(self, type_, bits):
self.type = type_
self.bits = bits
def __eq__(self, other):
return self.type == other.type and self.bits == other.bits
def __hash__(self):
return hash((self.type, self.bits))
def size_in_bytes(self):
return round_up_divide(self.bits, 8)
def size_in_bits(self):
return self.bits
def __str__(self):
stem, needs_format = DataType.stem_name[self.type]
if not needs_format:
return stem
else:
return stem % (self.bits,)
__repr__ = __str__
def as_numpy_type(self):
numpy_dtype_code = {
BaseType.UnsignedInt: "u",
BaseType.SignedInt: "i",
BaseType.Float: "f",
BaseType.Complex: "c",
}
assert self.type in numpy_dtype_code, f"Failed to interpret {self} as a numpy dtype"
return np.dtype(numpy_dtype_code[self.type] + str(self.size_in_bytes()))
stem_name = {
BaseType.UnsignedInt: ("uint%s", True),
BaseType.SignedInt: ("int%s", True),
BaseType.AsymmUInt: ("quint%s", True),
BaseType.AsymmSInt: ("qint%s", True),
BaseType.Float: ("float%s", True),
BaseType.BFloat: ("bfloat%s", True),
BaseType.Bool: ("bool", False),
BaseType.String: ("string", False),
BaseType.Resource: ("resource", False),
BaseType.Variant: ("variant", False),
BaseType.Complex: ("complex%s", True),
}
# generate the standard set of data types
DataType.int4 = DataType(BaseType.SignedInt, 4)
DataType.int8 = DataType(BaseType.SignedInt, 8)
DataType.int16 = DataType(BaseType.SignedInt, 16)
DataType.int32 = DataType(BaseType.SignedInt, 32)
DataType.int48 = DataType(BaseType.SignedInt, 48)
DataType.int64 = DataType(BaseType.SignedInt, 64)
DataType.uint8 = DataType(BaseType.UnsignedInt, 8)
DataType.uint16 = DataType(BaseType.UnsignedInt, 16)
DataType.uint32 = DataType(BaseType.UnsignedInt, 32)
DataType.uint64 = DataType(BaseType.UnsignedInt, 64)
DataType.quint4 = DataType(BaseType.AsymmUInt, 4)
DataType.quint8 = DataType(BaseType.AsymmUInt, 8)
DataType.quint12 = DataType(BaseType.AsymmUInt, 12)
DataType.quint16 = DataType(BaseType.AsymmUInt, 16)
DataType.quint32 = DataType(BaseType.AsymmUInt, 32)
DataType.qint4 = DataType(BaseType.AsymmSInt, 4)
DataType.qint8 = DataType(BaseType.AsymmSInt, 8)
DataType.qint12 = DataType(BaseType.AsymmSInt, 12)
DataType.qint16 = DataType(BaseType.AsymmSInt, 16)
DataType.qint32 = DataType(BaseType.AsymmSInt, 32)
DataType.float16 = DataType(BaseType.Float, 16)
DataType.float32 = DataType(BaseType.Float, 32)
DataType.float64 = DataType(BaseType.Float, 64)
DataType.string = DataType(BaseType.String, 64)
DataType.bool = DataType(BaseType.Bool, 8)
DataType.resource = DataType(BaseType.Resource, 8)
DataType.variant = DataType(BaseType.Variant, 8)
DataType.complex64 = DataType(BaseType.Complex, 64)
DataType.complex128 = DataType(BaseType.Complex, 128)