# SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15"""
16This script will provide you with an example of how to perform post-training quantization in TensorFlow.
17
18The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8bit
19integer values.
20
21Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
22Ethos NPU.
23
24In addition to quantizing weights, post-training quantization uses a calibration dataset to
25capture the minimum and maximum values of all variable tensors in your model.
26By capturing these ranges it is possible to fully quantize not just the weights of the model but also the activations.
27
28Depending on the model you are quantizing there may be some accuracy loss, but for a lot of models the loss should
29be minimal.
30
31If you are targetting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
32compiler for further optimizations before it can be used.
33
34For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
35For more information on post-training quantization
36see: https://www.tensorflow.org/lite/performance/post_training_integer_quant
37"""
import pathlib

import numpy as np
import tensorflow as tf

from training_utils import get_data, create_model


def post_training_quantize(keras_model, sample_data):
    """Quantize a Keras model using post-training quantization with some sample data.

    The resulting TensorFlow Lite model will have fp32 inputs/outputs and will handle
    quantizing/dequantizing internally.

    Args:
        keras_model: Keras model to quantize.
        sample_data: A numpy array of data to use as a representative dataset.

    Returns:
        Quantized TensorFlow Lite model.
    """

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    # We set the following converter options to ensure our model is fully quantized.
    # An error will be thrown if there are any ops that can't be quantized.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
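
    # Optionally, the converter can also be asked to produce a model with int8 inputs and outputs,
    # so that no float conversion happens at the model boundary. This is a sketch of the standard
    # TFLiteConverter options for that; it is not needed for this example, which keeps fp32 I/O:
    #   converter.inference_input_type = tf.int8
    #   converter.inference_output_type = tf.int8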

    # To use post-training quantization we must provide some sample data that will be used to
    # calculate activation ranges for quantization. This data should be representative of the data
    # we expect to feed the model and must be provided by a generator function.
    def generate_repr_dataset():
        for i in range(100):  # 100 samples is all we should need in this example.
            yield [np.expand_dims(sample_data[i], axis=0)]

    converter.representative_dataset = generate_repr_dataset
    tflite_model = converter.convert()

    return tflite_model


def evaluate_tflite_model(tflite_save_path, x_test, y_test):
    """Calculate the accuracy of a TensorFlow Lite model using the TensorFlow Lite interpreter.

    Args:
        tflite_save_path: Path to the TensorFlow Lite model to test.
        x_test: numpy array of testing data.
        y_test: numpy array of testing labels (sparse categorical).
    """

    interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path))

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
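
    # Note: this model was converted with fp32 inputs/outputs, so test data can be fed in directly.
    # Had it been converted with int8 inputs (see the optional converter settings sketched above),
    # each sample would first need quantizing with the input tensor's parameters, e.g. where `x`
    # is one test sample:
    #   scale, zero_point = input_details[0]['quantization']
    #   x_quant = np.round(x / scale + zero_point).astype(np.int8)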

    accuracy_count = 0
    num_test_images = len(y_test)

    for i in range(num_test_images):
        interpreter.set_tensor(input_details[0]['index'], x_test[i][np.newaxis, ...])
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])

        if np.argmax(output_data) == y_test[i]:
            accuracy_count += 1

    print(f"Test accuracy quantized: {accuracy_count / num_test_images:.3f}")


def main():
    x_train, y_train, x_test, y_test = get_data()
    model = create_model()

    # Compile and train the model in fp32 as normal.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])

    model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)

    # Test the fp32 model accuracy.
    test_loss, test_acc = model.evaluate(x_test, y_test)
    print(f"Test accuracy float: {test_acc:.3f}")

    # Quantize and export the resulting TensorFlow Lite model to file.
    tflite_model = post_training_quantize(model, x_train)

    tflite_models_dir = pathlib.Path('./conditioned_models/')
    tflite_models_dir.mkdir(exist_ok=True, parents=True)

    quant_model_save_path = tflite_models_dir / 'post_training_quant_model.tflite'
    with open(quant_model_save_path, 'wb') as f:
        f.write(tflite_model)

    # Test the quantized model accuracy. Save time by only testing a subset of the whole data.
    num_test_samples = 1000
    evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
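
    # Deployment note: to run this model on an Arm Ethos-U55 NPU, the saved .tflite file must first
    # be compiled with the Vela compiler. A command-line sketch, assuming Vela has been installed
    # with `pip install ethos-u-vela` (the accelerator configuration below is an example value):
    #   vela conditioned_models/post_training_quant_model.tflite --accelerator-config ethos-u55-128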


if __name__ == "__main__":
    main()