# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15"""
Alex Tawsedaba3cf2023-09-29 15:55:38 +010016This script will provide you with a short example of how to perform
17quantization aware training in TensorFlow using the
alexander3c798932021-03-26 21:42:19 +000018TensorFlow Model Optimization Toolkit.
19
Alex Tawsedaba3cf2023-09-29 15:55:38 +010020The output from this example will be a TensorFlow Lite model file
21where weights and activations are quantized to 8bit integer values.
alexander3c798932021-03-26 21:42:19 +000022
Alex Tawsedaba3cf2023-09-29 15:55:38 +010023Quantization helps reduce the size of your models and is necessary
24for running models on certain hardware such as Arm Ethos NPU.
alexander3c798932021-03-26 21:42:19 +000025
Alex Tawsedaba3cf2023-09-29 15:55:38 +010026In quantization aware training (QAT), the error introduced with
27quantizing from fp32 to int8 is simulated using fake quantization nodes.
28By simulating this quantization error when training,
29the model can learn better adapted weights and minimize accuracy losses
30caused by the reduced precision.
alexander3c798932021-03-26 21:42:19 +000031
Alex Tawsedaba3cf2023-09-29 15:55:38 +010032Minimum and maximum values for activations are also captured
33during training so activations for every layer can be quantized
34along with the weights later.
alexander3c798932021-03-26 21:42:19 +000035
Alex Tawsedaba3cf2023-09-29 15:55:38 +010036Quantization is only simulated during training and the
37training backward passes are still performed in full float precision.
38Actual quantization happens when generating a TensorFlow Lite model.
alexander3c798932021-03-26 21:42:19 +000039
Alex Tawsedaba3cf2023-09-29 15:55:38 +010040If you are targeting an Arm Ethos-U55 NPU then the output
41TensorFlow Lite file will also need to be passed through the Vela
alexander3c798932021-03-26 21:42:19 +000042compiler for further optimizations before it can be used.
43
Alex Tawsedaba3cf2023-09-29 15:55:38 +010044For more information on using vela see:
45 https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
46For more information on quantization aware training see:
47 https://www.tensorflow.org/model_optimization/guide/quantization/training
alexander3c798932021-03-26 21:42:19 +000048"""
import pathlib

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from training_utils import get_data, create_model


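# The module docstring mentions that QAT simulates quantization error with
# fake quantization nodes. The function below is a minimal illustrative
# sketch (an addition, not part of the original example): it reproduces the
# quantize/dequantize rounding with tf.quantization.fake_quant_with_min_max_args.
def fake_quant_demo():
    """Show the rounding error a fake quantization node introduces."""
    x = tf.constant([-1.0, -0.3, 0.0, 0.4, 1.0])
    # Quantize to 8-bit levels over the range [-1, 1], then dequantize back
    # to float, so the rounding error is visible to the surrounding graph.
    x_fq = tf.quantization.fake_quant_with_min_max_args(x, min=-1.0, max=1.0, num_bits=8)
    print("original:      ", x.numpy())
    print("fake-quantized:", x_fq.numpy())

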
def quantize_and_convert_to_tflite(keras_model):
    """Quantize and convert a Keras model trained with QAT to TensorFlow Lite.

    The TensorFlow Lite model will have fp32 inputs/outputs and will handle
    quantizing/dequantizing internally.

    Args:
        keras_model: Keras model trained with quantization aware training.

    Returns:
        Quantized TensorFlow Lite model.
    """

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    # After doing quantization aware training, all the information for creating a fully
    # quantized TensorFlow Lite model is already within the quantization aware Keras model.
    # This means we only need to call convert with default optimizations to
    # generate the quantized TensorFlow Lite model.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
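    # Note (an addition to the original example): if fully integer inputs and
    # outputs are needed, e.g. for an integer-only embedded runtime, the
    # converter can additionally be configured with:
    #   converter.inference_input_type = tf.int8
    #   converter.inference_output_type = tf.int8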
    tflite_model = converter.convert()

    return tflite_model


# pylint: disable=duplicate-code
def evaluate_tflite_model(tflite_save_path, x_test, y_test):
    """Calculate the accuracy of a TensorFlow Lite model using the TensorFlow Lite interpreter.

    Args:
        tflite_save_path: Path to the TensorFlow Lite model to test.
        x_test: numpy array of testing data.
        y_test: numpy array of testing labels (sparse categorical).
    """

    interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path))

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
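    # The QAT-converted model keeps fp32 inputs and outputs (see
    # quantize_and_convert_to_tflite above), so the float test data can be fed
    # in directly; input_details[0]['dtype'] can be checked if in doubt.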

    accuracy_count = 0
    num_test_images = len(y_test)

    for i in range(num_test_images):
        interpreter.set_tensor(input_details[0]['index'], x_test[i][np.newaxis, ...])
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])

        if np.argmax(output_data) == y_test[i]:
            accuracy_count += 1

    print(f"Test accuracy quantized: {accuracy_count / num_test_images:.3f}")


def main():
    """
    Run quantization aware training.
    """
    x_train, y_train, x_test, y_test = get_data()
    model = create_model()

    # When working with the TensorFlow Keras API and the TensorFlow Model
    # Optimization Toolkit, we can make our model quantization aware in one line.
    # Once this is done, we compile the model and train as normal.
    # It is important to note that the model is only quantization aware
    # and is not quantized yet.
    # The weights are still floating point and will only be converted
    # to int8 when we generate the TensorFlow Lite model later on.
    quant_aware_model = tfmot.quantization.keras.quantize_model(model)
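    # Note (an addition to the original example): quantize_model makes the
    # whole model quantization aware. If only certain layers should be
    # quantized, the toolkit also offers
    # tfmot.quantization.keras.quantize_annotate_layer together with
    # tfmot.quantization.keras.quantize_apply.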

    quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                              loss=tf.keras.losses.sparse_categorical_crossentropy,
                              metrics=['accuracy'])

    quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)

    # Test the quantization aware model accuracy.
    test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test)  # pylint: disable=unused-variable
    print(f"Test accuracy quant aware: {test_acc:.3f}")

    # Quantize and save the resulting TensorFlow Lite model to file.
    tflite_model = quantize_and_convert_to_tflite(quant_aware_model)

    tflite_models_dir = pathlib.Path('./conditioned_models/')
    tflite_models_dir.mkdir(exist_ok=True, parents=True)

    quant_model_save_path = tflite_models_dir / 'qat_quant_model.tflite'
    with open(quant_model_save_path, 'wb') as f:
        f.write(tflite_model)

    # Test quantized model accuracy. Save time by only testing a subset of the whole data.
    num_test_samples = 1000
    evaluate_tflite_model(
        quant_model_save_path,
        x_test[0:num_test_samples],
        y_test[0:num_test_samples]
    )
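    # For an Ethos-U55 target, the saved file would then be compiled with the
    # Vela compiler (see the module docstring). A typical invocation, shown
    # only as a sketch since flags depend on the target configuration:
    #   vela conditioned_models/qat_quant_model.tflite --accelerator-config=ethos-u55-128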
# pylint: enable=duplicate-code


if __name__ == "__main__":
    main()