blob: 6a5984873e336a773c384053b286882515542c38 [file] [log] [blame]
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
3from enum import IntEnum
4from enum import unique
5
6import tensorflow as tf
7
8
# Get a string name for a given shape and dtype, e.g. "1x4x4_f32".
def get_shape_str(shape, dtype):
    """Build a name fragment from a tensor shape and its dtype.

    Dimensions are rendered with ``str()`` and joined with ``"x"``; a short
    dtype suffix is then appended, e.g. shape ``[1, 4, 4]`` with
    ``tf.float32`` gives ``"1x4x4_f32"``.

    Args:
        shape: Iterable of dimensions.
        dtype: Dtype to encode; compared with ``==`` against each of the
            supported TensorFlow dtypes listed below.

    Returns:
        The combined shape/dtype name. A rank-0 (empty) shape is named "0".

    Raises:
        ValueError: If ``dtype`` matches none of the supported dtypes.
            ValueError subclasses Exception, so ``except Exception``
            handlers still catch it.
    """
    dims = [str(dim) for dim in shape]
    # A rank-0 shape has no dimensions to join. Give it the name "0" so the
    # dtype suffix can still be appended (starting from None would raise
    # TypeError on the concatenation below). Note that a 1-D zero-length
    # shape [0] also yields "0".
    shape_name = "x".join(dims) if dims else "0"

    # Ordered (dtype, suffix) pairs checked with ``==`` in the same order as
    # an if/elif chain, rather than a dict lookup, so matching semantics for
    # any dtype-like input stay the same.
    suffixes = (
        (tf.float32, "_f32"),
        (tf.float16, "_f16"),
        (tf.int32, "_i32"),
        (tf.uint32, "_u32"),
        (tf.bool, "_bool"),
        (tf.quint8, "_qu8"),
        (tf.qint8, "_qi8"),
        (tf.qint16, "_qi16"),
        (tf.quint16, "_qu16"),
        (tf.complex64, "_c64"),
    )
    for tf_dtype, suffix in suffixes:
        if dtype == tf_dtype:
            return shape_name + suffix

    raise ValueError("Unsupported type: {}".format(dtype))
39
40
@unique
class QuantType(IntEnum):
    """Quantization modes a test can be generated for.

    Member names encode integer types as I (signed) or U (unsigned) followed
    by a bit width; get_tf_dtype maps each mode to a TensorFlow quantized
    dtype. CONV_<A>_<B> modes name two types, presumably the input and
    weight types of a convolution (NOTE(review): confirm against callers).
    Values are explicit, and @unique rejects accidental duplicates.
    """

    # Not a recognised mode; get_tf_dtype returns None for it.
    UNKNOWN = 0

    # ALL_* modes: a single integer type.
    ALL_I8 = 1
    ALL_U8 = 2
    ALL_I16 = 3
    # TODO: support QUINT16

    # CONV_<A>_<B> modes: a pair of integer types.
    CONV_U8_U8 = 4
    CONV_I8_I8 = 5
    CONV_I8_I4 = 6
    CONV_I16_I8 = 7
52
53
def get_tf_dtype(quantized_inference_dtype):
    """Return the TensorFlow quantized dtype associated with a QuantType.

    For CONV_<A>_<B> modes the result follows the first type in the name
    (e.g. CONV_I16_I8 -> tf.qint16, CONV_I8_I4 -> tf.qint8).

    Args:
        quantized_inference_dtype: A QuantType member, or any value that
            compares equal to one (QuantType is an IntEnum, so plain ints
            match too).

    Returns:
        The matching ``tf`` quantized dtype, or None for UNKNOWN or any
        unrecognised value.
    """
    # Ordered pairs scanned with ``==`` (first match wins), mirroring an
    # if/elif chain. A dict lookup is avoided because it would raise on
    # unhashable inputs instead of returning None.
    mode_to_dtype = (
        (QuantType.ALL_I8, tf.qint8),
        (QuantType.ALL_U8, tf.quint8),
        (QuantType.ALL_I16, tf.qint16),
        (QuantType.CONV_U8_U8, tf.quint8),
        (QuantType.CONV_I8_I8, tf.qint8),
        (QuantType.CONV_I8_I4, tf.qint8),
        (QuantType.CONV_I16_I8, tf.qint16),
    )
    return next(
        (
            tf_dtype
            for mode, tf_dtype in mode_to_dtype
            if quantized_inference_dtype == mode
        ),
        None,
    )
71
72
class TensorScale:
    """Plain record holding min, max, num_bits and narrow_range values.

    Presumably these describe a quantization range for a tensor
    (NOTE(review): confirm with callers). The constructor arguments carry a
    leading underscore so they do not shadow the built-ins ``min``/``max``.

    Args:
        _min: Stored as ``self.min``.
        _max: Stored as ``self.max``.
        _num_bits: Stored as ``self.num_bits``.
        _narrow_range: Stored as ``self.narrow_range``.
    """

    def __init__(self, _min, _max, _num_bits, _narrow_range):
        self.min = _min
        self.max = _max
        self.num_bits = _num_bits
        self.narrow_range = _narrow_range

    def __repr__(self):
        # Show the stored values instead of the default "<... object at 0x...>".
        return (
            f"{type(self).__name__}(min={self.min!r}, max={self.max!r}, "
            f"num_bits={self.num_bits!r}, narrow_range={self.narrow_range!r})"
        )