# Copyright (c) 2021 Arm Limited. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script provides a short example of how to perform quantization aware training in TensorFlow using the
TensorFlow Model Optimization Toolkit.

The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8-bit
integer values.

Quantization helps reduce the size of your models and is necessary for running models on certain hardware, such as the
Arm Ethos-U NPU.

In quantization aware training (QAT), the error introduced by quantizing from fp32 to int8 is simulated using
fake quantization nodes. By simulating this quantization error during training, the model can learn better adapted
weights and minimize the accuracy loss caused by the reduced precision.

Minimum and maximum values for activations are also captured during training so activations for every layer can be
quantized along with the weights later.

Quantization is only simulated during training; the backward passes are still performed in full float
precision. Actual quantization happens when generating a TensorFlow Lite model.

If you are targeting an Arm Ethos-U55 NPU, the output TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.

For more information on using Vela, see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
For more information on quantization aware training,
see: https://www.tensorflow.org/model_optimization/guide/quantization/training
"""
import pathlib

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from training_utils import get_data, create_model


def quantize_and_convert_to_tflite(keras_model):
    """Quantize and convert Keras model trained with QAT to TensorFlow Lite.

    The TensorFlow Lite model will have fp32 inputs/outputs and will handle quantizing/dequantizing internally.

    Args:
        keras_model: Keras model trained with quantization aware training.

    Returns:
        Quantized TensorFlow Lite model.
    """

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    # After doing quantization aware training, all the information needed to create a fully quantized
    # TensorFlow Lite model is already within the quantization aware Keras model.
    # This means we only need to call convert with default optimizations to generate the quantized TensorFlow Lite model.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    return tflite_model


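# The conversion above keeps fp32 model inputs/outputs and lets TensorFlow Lite insert the
# quantize/dequantize ops. For deployment flows that expect integer tensors end to end (for
# example before compiling with Vela for an Ethos-U NPU), a full-integer variant along the
# lines of the sketch below can be used. This function is an illustrative addition, not part
# of the original script: its name and the use of a small slice of training data as the
# representative dataset are assumptions.
def convert_to_int8_io_tflite(keras_model, calibration_data):
    """Sketch: convert a QAT Keras model to a TensorFlow Lite model with int8 inputs/outputs."""
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    def representative_dataset():
        # Yield a handful of calibration samples, one at a time, each with a batch dimension.
        for sample in calibration_data[:100]:
            yield [sample[np.newaxis, ...].astype(np.float32)]

    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    return converter.convert()

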
def evaluate_tflite_model(tflite_save_path, x_test, y_test):
    """Calculate the accuracy of a TensorFlow Lite model using the TensorFlow Lite interpreter.

    Args:
        tflite_save_path: Path to TensorFlow Lite model to test.
        x_test: numpy array of testing data.
        y_test: numpy array of testing labels (sparse categorical).
    """

    interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path))

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    accuracy_count = 0
    num_test_images = len(y_test)

    for i in range(num_test_images):
        # Add a batch dimension and run inference on a single test sample.
        interpreter.set_tensor(input_details[0]['index'], x_test[i][np.newaxis, ...])
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])

        if np.argmax(output_data) == y_test[i]:
            accuracy_count += 1

    print(f"Test accuracy quantized: {accuracy_count / num_test_images:.3f}")


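# The min/max activation ranges captured during quantization aware training end up as scale and
# zero-point values in the converted model. A small helper such as the sketch below (an
# illustrative addition, not part of the original script) can be used to inspect the
# quantization parameters of the model's input and output tensors via the interpreter.
def print_io_quantization_params(tflite_save_path):
    """Sketch: print input/output quantization parameters of a TensorFlow Lite model."""
    interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path))
    interpreter.allocate_tensors()

    for detail in interpreter.get_input_details() + interpreter.get_output_details():
        # 'quantization' holds a (scale, zero_point) pair; both are 0 for tensors kept in float.
        scale, zero_point = detail['quantization']
        print(f"{detail['name']}: scale={scale}, zero_point={zero_point}")

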
def main():
    x_train, y_train, x_test, y_test = get_data()
    model = create_model()

    # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit, we can make our
    # model quantization aware in one line. Once this is done, we compile the model and train as normal.
    # It is important to note that the model is only quantization aware and is not quantized yet. The weights are
    # still floating point and will only be converted to int8 when we generate the TensorFlow Lite model later on.
    quant_aware_model = tfmot.quantization.keras.quantize_model(model)
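    # (Note: quantize_model makes the whole model quantization aware. If only selected layers
    # should be quantized, tfmot also offers quantize_annotate_layer/quantize_annotate_model
    # together with quantize_apply; that variant is not used in this example.)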

    quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                              loss=tf.keras.losses.sparse_categorical_crossentropy,
                              metrics=['accuracy'])

    quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)

    # Test the quantization aware model accuracy.
    test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test)
    print(f"Test accuracy quant aware: {test_acc:.3f}")

    # Quantize and save the resulting TensorFlow Lite model to file.
    tflite_model = quantize_and_convert_to_tflite(quant_aware_model)

    tflite_models_dir = pathlib.Path('./conditioned_models/')
    tflite_models_dir.mkdir(exist_ok=True, parents=True)

    quant_model_save_path = tflite_models_dir / 'qat_quant_model.tflite'
    with open(quant_model_save_path, 'wb') as f:
        f.write(tflite_model)

    # Test quantized model accuracy. Save time by only testing a subset of the whole data.
    num_test_samples = 1000
    evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
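    # As noted in the module docstring, for an Arm Ethos-U target the saved .tflite file would
    # additionally be compiled with the Vela compiler (basic usage: `vela <model>.tflite`)
    # before deployment; that step is outside the scope of this script.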


if __name__ == "__main__":
    main()