Jeremy Johnson | 015c355 | 2022-02-23 12:15:03 +0000 | [diff] [blame] | 1 | # Copyright (c) 2020-2022, ARM Limited. |
| 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | from enum import IntEnum |
| 4 | from enum import unique |
| 5 | |
| 6 | import tensorflow as tf |
| 7 | |
| 8 | |
# Ordered (dtype, suffix) pairs consulted by get_shape_str. The pairs are
# compared with ``==`` in this order, exactly as the former if/elif chain did,
# so matching follows whatever tf.DType equality accepts.
_DTYPE_SUFFIXES = (
    (tf.float32, "_f32"),
    (tf.float16, "_f16"),
    (tf.int32, "_i32"),
    (tf.uint32, "_u32"),
    (tf.bool, "_bool"),
    (tf.quint8, "_qu8"),
    (tf.qint8, "_qi8"),
    (tf.qint16, "_qi16"),
    (tf.quint16, "_qu16"),
    (tf.complex64, "_c64"),
)


def get_shape_str(shape, dtype):
    """Return a string name for a given shape and dtype.

    The dimensions are joined with ``"x"`` and followed by a dtype suffix.
    For example, shape ``[1, 4, 4]`` with ``tf.float32`` gives
    ``"1x4x4_f32"``. A zero-length shape is rendered as ``"0"``.

    Args:
        shape: Sized iterable of dimension sizes; each is formatted with ``str``.
        dtype: TensorFlow dtype; must equal one of the dtypes in
            ``_DTYPE_SUFFIXES``.

    Returns:
        The shape name string.

    Raises:
        ValueError: If ``dtype`` is not supported. This is a subclass of
            ``Exception``, so handlers written for the previous bare
            ``Exception`` still catch it.
    """
    # A zero-length shape would otherwise yield an empty name, so use "0".
    if len(shape) == 0:
        shape_name = "0"
    else:
        shape_name = "x".join(str(dim) for dim in shape)

    for tf_dtype, suffix in _DTYPE_SUFFIXES:
        if dtype == tf_dtype:
            return shape_name + suffix

    raise ValueError("Unsupported type: {}".format(dtype))
| 42 | |
| 43 | |
@unique
class QuantType(IntEnum):
    """Enumeration of quantization configurations.

    ``@unique`` guarantees that no two members share a value. As an
    ``IntEnum``, members compare equal to their integer values.
    ``get_tf_dtype`` in this module maps every member except ``UNKNOWN`` to a
    TensorFlow quantized dtype.
    """

    # get_tf_dtype returns None for this value (it matches no branch there).
    UNKNOWN = 0
    # ALL_*: a single type is named. Presumably it applies to every tensor;
    # TODO confirm.
    ALL_I8 = 1
    ALL_U8 = 2
    ALL_I16 = 3
    # TODO: support QUINT16
    # CONV_<A>_<B>: two types are named, and get_tf_dtype returns the dtype
    # of <A>. Presumably <A> is the activation type and <B> the weight type
    # (e.g. I8_I4 = 8-bit activations, 4-bit weights); confirm with callers.
    CONV_U8_U8 = 4
    CONV_I8_I8 = 5
    CONV_I8_I4 = 6
    CONV_I16_I8 = 7
| 55 | |
| 56 | |
def get_tf_dtype(quantized_inference_dtype):
    """Map a QuantType to its TensorFlow quantized dtype.

    For the ``CONV_<A>_<B>`` members, the dtype of the first-named type
    ``<A>`` is returned.

    Args:
        quantized_inference_dtype: A ``QuantType`` member, or a value
            comparing equal to one.

    Returns:
        ``tf.qint8``, ``tf.quint8`` or ``tf.qint16``. Returns ``None`` when
        there is no mapping, e.g. for ``QuantType.UNKNOWN``.
    """
    signed_8bit = (QuantType.ALL_I8, QuantType.CONV_I8_I8, QuantType.CONV_I8_I4)
    unsigned_8bit = (QuantType.ALL_U8, QuantType.CONV_U8_U8)
    signed_16bit = (QuantType.ALL_I16, QuantType.CONV_I16_I8)

    if quantized_inference_dtype in signed_8bit:
        return tf.qint8
    if quantized_inference_dtype in unsigned_8bit:
        return tf.quint8
    if quantized_inference_dtype in signed_16bit:
        return tf.qint16
    return None
| 74 | |
| 75 | |
class TensorScale:
    """Simple record of min/max bounds, a bit width and a narrow-range flag.

    NOTE(review): presumably these describe a quantization range for a
    tensor; confirm against callers.

    Attributes:
        min: Lower bound, as passed in.
        max: Upper bound, as passed in.
        num_bits: Bit width, as passed in.
        narrow_range: Narrow-range flag, as passed in.
    """

    def __init__(self, _min, _max, _num_bits, _narrow_range):
        # Values are stored verbatim; no validation or conversion is done.
        self.min, self.max, self.num_bits, self.narrow_range = (
            _min,
            _max,
            _num_bits,
            _narrow_range,
        )