# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
from enum import IntEnum
from enum import unique

import tensorflow as tf

# Get a string name for a given shape
def get_shape_str(shape, dtype):
    """Return a compact name string encoding ``shape`` and ``dtype``.

    Dimensions are joined with "x" (e.g. ``(2, 3)`` -> ``"2x3"``); a
    rank-0 (scalar) shape is encoded as ``"0"``.  A dtype suffix is then
    appended, e.g. ``"2x3_f32"``.

    Raises:
        Exception: if ``dtype`` is not one of the supported tf dtypes.
    """
    # Lookup table replaces a long if/elif chain; add new dtypes here.
    dtype_suffixes = {
        tf.float32: "f32",
        tf.float16: "f16",
        tf.int32: "i32",
        tf.uint32: "u32",
        tf.bool: "bool",
        tf.quint8: "qu8",
        tf.qint8: "qi8",
        tf.qint16: "qi16",
        tf.quint16: "qu16",
        tf.complex64: "c64",
    }

    if dtype not in dtype_suffixes:
        raise Exception("Unsupported type: {}".format(dtype))

    # Rank-0 shapes get the special name "0"; otherwise join dimensions.
    if len(shape) == 0:
        shape_name = "0"
    else:
        shape_name = "x".join(str(dim) for dim in shape)

    return shape_name + "_" + dtype_suffixes[dtype]
42
43
@unique
class QuantType(IntEnum):
    """Quantization mode selector for test generation.

    The member chosen determines which quantized TensorFlow dtype
    ``get_tf_dtype`` returns (e.g. ALL_I8 -> tf.qint8).  Values are
    explicit fixed integers; avoid renumbering existing members.
    """

    UNKNOWN = 0  # no quantization selected; get_tf_dtype returns None
    ALL_I8 = 1  # maps to tf.qint8
    ALL_U8 = 2  # maps to tf.quint8
    ALL_I16 = 3  # maps to tf.qint16
    # TODO: support QUINT16
    # CONV_* variants: presumably activation/weight dtype pairs for
    # convolutions — TODO confirm; get_tf_dtype keys off the first dtype.
    CONV_U8_U8 = 4  # maps to tf.quint8
    CONV_I8_I8 = 5  # maps to tf.qint8
    CONV_I8_I4 = 6  # maps to tf.qint8
    CONV_I16_I8 = 7  # maps to tf.qint16
55
56
def get_tf_dtype(quantized_inference_dtype):
    """Map a ``QuantType`` member to its quantized TensorFlow dtype.

    Returns ``None`` when the value has no associated tf dtype
    (e.g. ``QuantType.UNKNOWN``).
    """
    quant_to_tf = {
        QuantType.ALL_I8: tf.qint8,
        QuantType.ALL_U8: tf.quint8,
        QuantType.ALL_I16: tf.qint16,
        QuantType.CONV_U8_U8: tf.quint8,
        QuantType.CONV_I8_I8: tf.qint8,
        QuantType.CONV_I8_I4: tf.qint8,
        QuantType.CONV_I16_I8: tf.qint16,
    }
    # .get defaults to None, matching the original else branch.
    return quant_to_tf.get(quantized_inference_dtype)
74
75
class TensorScale:
    """Holds per-tensor quantization range parameters.

    Attributes mirror the constructor arguments: ``min``/``max`` bound
    the representable value range, ``num_bits`` is the quantized bit
    width, and ``narrow_range`` is a flag (presumably matching the
    TensorFlow fake-quant ``narrow_range`` option — TODO confirm with
    callers).
    """

    def __init__(self, _min, _max, _num_bits, _narrow_range):
        # Parameter names are kept as-is for backward compatibility
        # with any callers that pass them by keyword.
        self.min = _min
        self.max = _max
        self.num_bits = _num_bits
        self.narrow_range = _narrow_range

    def __repr__(self):
        # Added for debuggability; attribute values echo the constructor.
        return "{}(min={!r}, max={!r}, num_bits={!r}, narrow_range={!r})".format(
            type(self).__name__, self.min, self.max, self.num_bits, self.narrow_range
        )