blob: 2d8e5d60f9e1988a3e09b082f3a85842843d9d31 [file] [log] [blame]
Jeremy Johnson015c3552022-02-23 12:15:03 +00001# Copyright (c) 2020-2022, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3from enum import IntEnum
4from enum import unique
5
6import tensorflow as tf
7
8
# Name suffix appended by get_shape_str for each supported TensorFlow dtype.
# Entries are tried in order with ``==`` (rather than a dict lookup) so the
# comparison semantics are exactly those of an explicit if/elif chain.
_DTYPE_SUFFIXES = (
    (tf.float32, "_f32"),
    (tf.float16, "_f16"),
    (tf.int32, "_i32"),
    (tf.uint32, "_u32"),
    (tf.bool, "_bool"),
    (tf.quint8, "_qu8"),
    (tf.qint8, "_qi8"),
    (tf.qint16, "_qi16"),
    (tf.quint16, "_qu16"),
)


# Get a string name for a given shape
def get_shape_str(shape, dtype):
    """Return a name such as ``"2x3x4_f32"`` describing *shape* and *dtype*.

    The dimensions of *shape* are converted with ``str()`` and joined with
    ``"x"``, then a suffix identifying *dtype* is appended (see
    ``_DTYPE_SUFFIXES``). A rank-0 (scalar) shape is named ``"0"``,
    giving e.g. ``"0_f32"``.

    Args:
        shape: Iterable of dimensions.
        dtype: TensorFlow dtype to encode in the name.

    Returns:
        The combined shape/type name string.

    Raises:
        Exception: If *dtype* is not one of the supported dtypes.
    """
    # A scalar shape has no dimensions to join; give it an explicit name so
    # the result is never just the dtype suffix. Note this is the same name
    # a 1-D shape [0] produces.
    shape_name = "x".join(str(dim) for dim in shape) or "0"

    for tf_dtype, suffix in _DTYPE_SUFFIXES:
        if dtype == tf_dtype:
            return shape_name + suffix

    raise Exception("Unsupported type: {}".format(dtype))
37
38
@unique
class QuantType(IntEnum):
    """Enumerates quantization configurations used by this module.

    ``get_tf_dtype`` maps every member except ``UNKNOWN`` to a TensorFlow
    quantized dtype. ``@unique`` rejects duplicate values at class creation,
    and as an ``IntEnum`` each member also compares equal to its int value.
    The CONV_<A>_<B> names presumably describe a pair of element types for
    convolution-style ops (e.g. activations/weights) — TODO confirm with
    callers.
    """

    UNKNOWN = 0  # get_tf_dtype returns None for this member
    ALL_I8 = 1  # get_tf_dtype -> tf.qint8
    ALL_U8 = 2  # get_tf_dtype -> tf.quint8
    ALL_I16 = 3  # get_tf_dtype -> tf.qint16
    # TODO: support QUINT16
    CONV_U8_U8 = 4  # get_tf_dtype -> tf.quint8
    CONV_I8_I8 = 5  # get_tf_dtype -> tf.qint8
    CONV_I8_I4 = 6  # get_tf_dtype -> tf.qint8
    CONV_I16_I8 = 7  # get_tf_dtype -> tf.qint16
50
51
def get_tf_dtype(quantized_inference_dtype):
    """Map a QuantType to the TensorFlow quantized dtype used for inference.

    Args:
        quantized_inference_dtype: A ``QuantType`` member (or an int equal to
            one, since ``QuantType`` is an ``IntEnum``).

    Returns:
        ``tf.qint8``, ``tf.quint8`` or ``tf.qint16``; ``None`` for
        ``QuantType.UNKNOWN`` or any other unrecognised value.
    """
    # Each TensorFlow dtype is listed together with every scheme that uses it.
    dtype_groups = (
        (
            tf.qint8,
            (QuantType.ALL_I8, QuantType.CONV_I8_I8, QuantType.CONV_I8_I4),
        ),
        (tf.quint8, (QuantType.ALL_U8, QuantType.CONV_U8_U8)),
        (tf.qint16, (QuantType.ALL_I16, QuantType.CONV_I16_I8)),
    )

    for tf_dtype, quant_types in dtype_groups:
        if any(quantized_inference_dtype == qtype for qtype in quant_types):
            return tf_dtype

    # UNKNOWN, or a value that matches no scheme.
    return None
69
70
class TensorScale:
    """Plain record holding range and bit-width settings for a tensor.

    The attribute names mirror the ``min``/``max``/``num_bits``/
    ``narrow_range`` arguments of TensorFlow's fake-quantization ops.
    Presumably that is how callers consume them — confirm against callers.
    Equality is identity-based (no ``__eq__`` is defined).
    """

    def __init__(self, _min, _max, _num_bits, _narrow_range):
        """Store the given values unchanged; no validation is performed."""
        self.min = _min
        self.max = _max
        self.num_bits = _num_bits
        self.narrow_range = _narrow_range

    def __repr__(self):
        # A readable repr makes generated test parameters easy to inspect.
        return "{}(min={!r}, max={!r}, num_bits={!r}, narrow_range={!r})".format(
            type(self).__name__,
            self.min,
            self.max,
            self.num_bits,
            self.narrow_range,
        )