blob: f31ac63af01b9a3b175fce0c02a978fdbbf1c031 [file] [log] [blame]
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
from enum import IntEnum
from enum import unique
import tensorflow as tf
def get_shape_str(shape, dtype):
    """Return a string name for a given shape and dtype.

    The dimensions of ``shape`` are rendered with ``str()`` and joined with
    ``"x"``; an empty shape is rendered as ``"0"``. A short dtype suffix is
    then appended, e.g. ``get_shape_str([1, 4, 4], tf.float32)`` returns
    ``"1x4x4_f32"`` and ``get_shape_str([], tf.int32)`` returns ``"0_i32"``.

    Args:
        shape: Sized iterable of dimensions (``len()`` is taken, then it is
            iterated).
        dtype: Tensor dtype, compared with ``==`` against the supported
            ``tf`` dtypes listed in the table below.

    Returns:
        The shape name string.

    Raises:
        ValueError: if ``dtype`` matches none of the supported dtypes.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    # Ordered (dtype, suffix) pairs, checked first-match-wins. Matching uses
    # ``==`` rather than a dict lookup so that values which compare equal to
    # a tf dtype but hash differently still match, as with an if/elif chain.
    dtype_suffixes = (
        (tf.float32, "_f32"),
        (tf.float16, "_f16"),
        (tf.int32, "_i32"),
        (tf.uint32, "_u32"),
        (tf.bool, "_bool"),
        (tf.quint8, "_qu8"),
        (tf.qint8, "_qi8"),
        (tf.qint16, "_qi16"),
        (tf.quint16, "_qu16"),
        (tf.complex64, "_c64"),
    )

    if len(shape) == 0:
        # Rank-0 (scalar) shapes have no dimensions to join.
        shape_name = "0"
    else:
        shape_name = "x".join(str(dim) for dim in shape)

    for candidate, suffix in dtype_suffixes:
        if dtype == candidate:
            return shape_name + suffix

    raise ValueError("Unsupported type: {}".format(dtype))
@unique
class QuantType(IntEnum):
    """Enumeration of quantization configurations.

    ``@unique`` enforces that no two members share a value. Because this is an
    ``IntEnum``, members also compare equal to their plain integer values.
    ``get_tf_dtype`` maps every member except ``UNKNOWN`` to a TensorFlow
    quantized dtype and returns ``None`` for ``UNKNOWN``.

    NOTE(review): member names appear to encode integer widths/signedness
    (I = signed, U = unsigned); for the ``CONV_*`` members they presumably
    name the input and weight types of convolution ops -- confirm with callers.
    """

    UNKNOWN = 0  # no quantization configuration selected
    ALL_I8 = 1  # mapped to tf.qint8 by get_tf_dtype
    ALL_U8 = 2  # mapped to tf.quint8 by get_tf_dtype
    ALL_I16 = 3  # mapped to tf.qint16 by get_tf_dtype
    # TODO: support QUINT16
    CONV_U8_U8 = 4  # mapped to tf.quint8 by get_tf_dtype
    CONV_I8_I8 = 5  # mapped to tf.qint8 by get_tf_dtype
    CONV_I8_I4 = 6  # mapped to tf.qint8 by get_tf_dtype (no 4-bit tf dtype is used)
    CONV_I16_I8 = 7  # mapped to tf.qint16 by get_tf_dtype
def get_tf_dtype(quantized_inference_dtype):
    """Return the TensorFlow quantized dtype for a ``QuantType``.

    Args:
        quantized_inference_dtype: A ``QuantType`` member, or any value that
            compares equal to one (e.g. its integer value).

    Returns:
        ``tf.qint8`` for ALL_I8, CONV_I8_I8 and CONV_I8_I4;
        ``tf.quint8`` for ALL_U8 and CONV_U8_U8;
        ``tf.qint16`` for ALL_I16 and CONV_I16_I8;
        ``None`` for anything else, including ``QuantType.UNKNOWN``.
    """
    # Scanned in order with ``==`` (input on the left), so the first matching
    # entry wins -- identical to a chain of if/elif comparisons.
    quant_to_tf = (
        (QuantType.ALL_I8, tf.qint8),
        (QuantType.ALL_U8, tf.quint8),
        (QuantType.ALL_I16, tf.qint16),
        (QuantType.CONV_U8_U8, tf.quint8),
        (QuantType.CONV_I8_I8, tf.qint8),
        (QuantType.CONV_I8_I4, tf.qint8),
        (QuantType.CONV_I16_I8, tf.qint16),
    )
    for qtype, tf_dtype in quant_to_tf:
        if quantized_inference_dtype == qtype:
            return tf_dtype
    return None
class TensorScale:
    """Plain record holding a tensor's quantization range parameters.

    The constructor arguments are stored unchanged as the attributes ``min``,
    ``max``, ``num_bits`` and ``narrow_range``. No validation is performed.

    NOTE(review): presumably ``min``/``max`` are float range bounds,
    ``num_bits`` an integer bit width and ``narrow_range`` a bool, as used by
    fake-quantization ops -- confirm against call sites.
    """

    def __init__(self, _min, _max, _num_bits, _narrow_range):
        self.min = _min
        self.max = _max
        self.num_bits = _num_bits
        self.narrow_range = _narrow_range

    def __repr__(self):
        # Debug-friendly representation; __eq__ is intentionally not defined
        # so instances stay hashable by identity, as before.
        return "{}(min={!r}, max={!r}, num_bits={!r}, narrow_range={!r})".format(
            type(self).__name__,
            self.min,
            self.max,
            self.num_bits,
            self.narrow_range,
        )