| # Copyright (c) 2020-2022, ARM Limited. |
| # SPDX-License-Identifier: Apache-2.0 |
| from enum import IntEnum |
| from enum import unique |
| |
| import tensorflow as tf |
| |
| |
# Get a string name for a given shape
def get_shape_str(shape, dtype):
    """Build a compact string label for a tensor shape and dtype.

    The dimensions are stringified and joined with ``"x"``; an empty shape
    is rendered as ``"0"``. A short dtype suffix is then appended, e.g.
    ``[1, 8, 8]`` with ``tf.float32`` becomes ``"1x8x8_f32"``.

    Args:
        shape: Sized iterable of dimensions (``len()`` is taken, then it is
            iterated); each dimension is rendered with ``str()``.
        dtype: A TensorFlow dtype, or any value that compares equal (``==``)
            to one of the supported TensorFlow dtypes.

    Returns:
        The shape label followed by its dtype suffix.

    Raises:
        ValueError: If ``dtype`` is not one of the supported dtypes.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers keep catching it.
    """
    # Ordered (dtype, suffix) pairs. Matching is done with ``==`` (exactly
    # as the former if/elif chain did) rather than a dict lookup, so values
    # that merely compare equal to a TensorFlow dtype are still accepted.
    dtype_suffixes = (
        (tf.float32, "_f32"),
        (tf.float16, "_f16"),
        (tf.int32, "_i32"),
        (tf.uint32, "_u32"),
        (tf.bool, "_bool"),
        (tf.quint8, "_qu8"),
        (tf.qint8, "_qi8"),
        (tf.qint16, "_qi16"),
        (tf.quint16, "_qu16"),
        (tf.complex64, "_c64"),
    )

    # An empty (rank-0) shape still needs a non-empty label.
    if len(shape) == 0:
        shape_name = "0"
    else:
        shape_name = "x".join(str(dim) for dim in shape)

    for tf_dtype, suffix in dtype_suffixes:
        if dtype == tf_dtype:
            return shape_name + suffix

    raise ValueError("Unsupported type: {}".format(dtype))
| |
| |
@unique
class QuantType(IntEnum):
    """Selector for the quantization mode used when generating a test.

    ``ALL_<T>`` members name a single integer type. ``CONV_<A>_<B>`` members
    name two types (presumably activation and weight types respectively —
    TODO confirm against callers). ``@unique`` makes enum creation fail if
    two members share a value. Members are ``IntEnum``s, so they also compare
    equal to their integer values.
    """

    # Fallback / unset mode; it has no entry in the TensorFlow dtype mapping
    # (get_tf_dtype falls through to returning None for it).
    UNKNOWN = 0
    ALL_I8 = 1
    ALL_U8 = 2
    ALL_I16 = 3
    # TODO: support QUINT16
    CONV_U8_U8 = 4
    CONV_I8_I8 = 5
    CONV_I8_I4 = 6
    CONV_I16_I8 = 7
| |
| |
def get_tf_dtype(quantized_inference_dtype):
    """Return the TensorFlow quantized dtype associated with a QuantType.

    Args:
        quantized_inference_dtype: A ``QuantType`` member, or any value that
            compares equal to one (e.g. its integer value, since
            ``QuantType`` is an ``IntEnum``).

    Returns:
        ``tf.qint8``, ``tf.quint8`` or ``tf.qint16`` for the known modes, or
        ``None`` for anything else (including ``QuantType.UNKNOWN``).
    """
    # The CONV_* modes map according to the first type in their name,
    # e.g. CONV_I16_I8 -> tf.qint16 and CONV_I8_I4 -> tf.qint8.
    mode_to_dtype = (
        (QuantType.ALL_I8, tf.qint8),
        (QuantType.ALL_U8, tf.quint8),
        (QuantType.ALL_I16, tf.qint16),
        (QuantType.CONV_U8_U8, tf.quint8),
        (QuantType.CONV_I8_I8, tf.qint8),
        (QuantType.CONV_I8_I4, tf.qint8),
        (QuantType.CONV_I16_I8, tf.qint16),
    )
    # First entry whose mode compares equal wins; no match yields None.
    return next(
        (
            tf_dtype
            for mode, tf_dtype in mode_to_dtype
            if quantized_inference_dtype == mode
        ),
        None,
    )
| |
| |
class TensorScale:
    """Plain record of range parameters for a tensor.

    Presumably these describe a (fake-)quantization range — confirm against
    callers.

    Attributes:
        min: Lower bound of the range.
        max: Upper bound of the range.
        num_bits: Bit width.
        narrow_range: Narrow-range flag.
    """

    def __init__(self, _min, _max, _num_bits, _narrow_range):
        # Arguments are stored verbatim; no validation or conversion happens.
        self.min, self.max = _min, _max
        self.num_bits, self.narrow_range = _num_bits, _narrow_range