blob: 54f241c4320fc380ba779d41177c5de338d4b8b5 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001# Copyright (c) 2021 Arm Limited. All rights reserved.
2# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""
16This script will provide you with a short example of how to perform clustering of weights (weight sharing) in
17TensorFlow using the TensorFlow Model Optimization Toolkit.
18
The output from this example will be a TensorFlow Lite model file where weights in each layer have been 'clustered' into
16 clusters during training - quantization has then been applied on top of this.
21
22By clustering the model we can improve compression of the model file. This can be essential for deploying certain
23models on systems with limited resources - such as embedded systems using an Arm Ethos NPU.
24
25After performing clustering we do post-training quantization to quantize the model and then generate a TensorFlow Lite file.
26
If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.
29
30For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
31For more information on clustering see: https://www.tensorflow.org/model_optimization/guide/clustering
32"""
33import pathlib
34
35import tensorflow as tf
36import tensorflow_model_optimization as tfmot
37
38from training_utils import get_data, create_model
39from post_training_quantization import post_training_quantize, evaluate_tflite_model
40
41
def prepare_for_clustering(keras_model, number_of_clusters=16, cluster_centroids_init=None):
    """Wrap a Keras model for weight clustering and recompile it.

    Args:
        keras_model: The Keras model to cluster. It is handed as-is to
            ``tfmot.clustering.keras.cluster_weights``.
        number_of_clusters: How many clusters (unique weight values) to use when
            clustering each layer. Defaults to 16.
        cluster_centroids_init: A ``tfmot.clustering.keras.CentroidInitialization``
            member selecting how the cluster centroids are initialized. When ``None``
            (the default), ``CentroidInitialization.LINEAR`` is used.

    Returns:
        The clustering-wrapped model, compiled with the Adam optimizer
        (learning rate 0.001), sparse categorical cross-entropy loss and the
        'accuracy' metric.
    """

    # Choose the number of clusters to use and how to initialize them. Using fewer clusters improves
    # compression but will generally reduce accuracy, so you will need to find the optimal number for
    # your use-case.
    if cluster_centroids_init is None:
        # Resolved here rather than in the signature so the default is looked up at call time.
        cluster_centroids_init = tfmot.clustering.keras.CentroidInitialization.LINEAR

    # Apply the clustering wrapper to the whole model so weights in every layer will get clustered. You may find that
    # to avoid too much accuracy loss only certain non-critical layers in your model should be clustered.
    clustering_ready_model = tfmot.clustering.keras.cluster_weights(
        keras_model,
        number_of_clusters=number_of_clusters,
        cluster_centroids_init=cluster_centroids_init,
    )

    # We must recompile the model after making it ready for clustering.
    clustering_ready_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy'],
    )

    return clustering_ready_model
62
63
def main():
    """Train a model, fine-tune it with weight clustering, quantize it and save it as TensorFlow Lite.

    The baseline model is trained for 5 epochs, then wrapped for clustering and
    trained for 1 further epoch. After the clustering wrappers are stripped, the
    model is post-training quantized, written to
    ``./conditioned_models/clustered_post_training_quant_model.tflite`` and
    evaluated on the first 1000 test samples.
    """
    x_train, y_train, x_test, y_test = get_data()
    baseline_model = create_model()

    # Train a regular (unclustered) model first: clustering is generally easier to apply
    # as a fine-tuning step once the model is fully trained.
    baseline_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy'],
    )
    baseline_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)

    # Baseline accuracy, for comparison against the clustered model below.
    _, test_acc = baseline_model.evaluate(x_test, y_test)
    print(f"Test accuracy before clustering: {test_acc:.3f}")

    # Wrap the trained model for clustering and fine-tune it with clustering applied.
    clustered_model = prepare_for_clustering(baseline_model)
    clustered_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True)

    _, test_acc = clustered_model.evaluate(x_test, y_test)
    print(f"Test accuracy after clustering: {test_acc:.3f}")

    # Drop the variables that clustering only needed during training.
    export_model = tfmot.clustering.keras.strip_clustering(clustered_model)

    # Quantize the clustered model post-training; the result is written to a TensorFlow Lite file below.
    tflite_model = post_training_quantize(export_model, x_train)

    output_dir = pathlib.Path('./conditioned_models/')
    output_dir.mkdir(exist_ok=True, parents=True)

    tflite_path = output_dir / 'clustered_post_training_quant_model.tflite'
    with tflite_path.open('wb') as tflite_file:
        tflite_file.write(tflite_model)

    # Evaluate the clustered, quantized model on a subset of the test data only, to save time.
    num_test_samples = 1000
    evaluate_tflite_model(tflite_path, x_test[:num_test_samples], y_test[:num_test_samples])
105
# Run the clustering example only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()