Jeremy Johnson | 015c355 | 2022-02-23 12:15:03 +0000 | [diff] [blame] | 1 | # Copyright (c) 2020-2022, ARM Limited. |
| 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | from enum import IntEnum |
| 4 | from enum import unique |
| 5 | |
| 6 | import tensorflow as tf |
| 7 | |
| 8 | |
| 9 | # Get a string name for a given shape |
| 10 | def get_shape_str(shape, dtype): |
| 11 | shape_name = None |
| 12 | for dim in shape: |
| 13 | shape_name = (shape_name + "x" + str(dim)) if shape_name else str(dim) |
| 14 | |
| 15 | if dtype == tf.float32: |
| 16 | shape_name = shape_name + "_f32" |
| 17 | elif dtype == tf.float16: |
| 18 | shape_name = shape_name + "_f16" |
| 19 | elif dtype == tf.int32: |
| 20 | shape_name = shape_name + "_i32" |
| 21 | elif dtype == tf.uint32: |
| 22 | shape_name = shape_name + "_u32" |
| 23 | elif dtype == tf.bool: |
| 24 | shape_name = shape_name + "_bool" |
| 25 | elif dtype == tf.quint8: |
| 26 | shape_name = shape_name + "_qu8" |
| 27 | elif dtype == tf.qint8: |
| 28 | shape_name = shape_name + "_qi8" |
| 29 | elif dtype == tf.qint16: |
| 30 | shape_name = shape_name + "_qi16" |
| 31 | elif dtype == tf.quint16: |
| 32 | shape_name = shape_name + "_qu16" |
Luke Hutton | 714aa60 | 2023-02-08 19:45:26 +0000 | [diff] [blame] | 33 | elif dtype == tf.complex64: |
| 34 | shape_name = shape_name + "_c64" |
Jeremy Johnson | 015c355 | 2022-02-23 12:15:03 +0000 | [diff] [blame] | 35 | else: |
| 36 | raise Exception("Unsupported type: {}".format(dtype)) |
| 37 | |
| 38 | return shape_name |
| 39 | |
| 40 | |
@unique
class QuantType(IntEnum):
    """Quantization scheme selector for generated tests.

    ``ALL_*`` members quantize every tensor to one type; ``CONV_*``
    members name the activation/weight type pair used for convolutions.
    Values are stable identifiers — do not renumber.
    """

    # Not a recognized/selected quantization scheme.
    UNKNOWN = 0
    # Uniform quantization: one type for all tensors.
    ALL_I8 = 1
    ALL_U8 = 2
    ALL_I16 = 3
    # TODO: support QUINT16
    # Convolution schemes: <activation type>_<weight type>.
    CONV_U8_U8 = 4
    CONV_I8_I8 = 5
    CONV_I8_I4 = 6
    CONV_I16_I8 = 7
| 52 | |
| 53 | |
def get_tf_dtype(quantized_inference_dtype):
    """Return the TensorFlow quantized dtype for a QuantType value.

    Args:
        quantized_inference_dtype: a ``QuantType`` member.

    Returns:
        The corresponding tf quantized dtype (``tf.qint8``,
        ``tf.quint8``, or ``tf.qint16``), or ``None`` for unmapped
        values such as ``QuantType.UNKNOWN``.
    """
    # Lookup table replaces the original seven-branch elif ladder; for
    # CONV_* schemes this is the activation (inference) type.
    quant_to_tf = {
        QuantType.ALL_I8: tf.qint8,
        QuantType.ALL_U8: tf.quint8,
        QuantType.ALL_I16: tf.qint16,
        QuantType.CONV_U8_U8: tf.quint8,
        QuantType.CONV_I8_I8: tf.qint8,
        QuantType.CONV_I8_I4: tf.qint8,
        QuantType.CONV_I16_I8: tf.qint16,
    }
    # dict.get preserves the original behavior of returning None for
    # anything not listed above.
    return quant_to_tf.get(quantized_inference_dtype)
| 71 | |
| 72 | |
class TensorScale:
    """Quantization range parameters for a tensor.

    Bundles the min/max/num_bits/narrow_range quartet — presumably the
    arguments fed to TensorFlow fake-quant ops (TODO confirm against
    callers).
    """

    def __init__(self, _min, _max, _num_bits, _narrow_range):
        # Lower and upper bounds of the representable value range.
        self.min = _min
        self.max = _max
        # Bit width of the quantized representation.
        self.num_bits = _num_bits
        # Whether the lowest quantized value is excluded (narrow range).
        self.narrow_range = _narrow_range

    def __repr__(self):
        # Debug-friendly representation; the original class had none.
        return (
            "{}(min={!r}, max={!r}, num_bits={!r}, narrow_range={!r})".format(
                type(self).__name__,
                self.min,
                self.max,
                self.num_bits,
                self.narrow_range,
            )
        )