blob: e966336c8088ed6cfa08d92416a2a3b9c59f0c96 [file] [log] [blame]
Alex Tawsedaba3cf2023-09-29 15:55:38 +01001# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
alexander3c798932021-03-26 21:42:19 +00002# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""
Alex Tawsedaba3cf2023-09-29 15:55:38 +010016This script will provide you with a short example of how to perform
17clustering of weights (weight sharing) in TensorFlow
18using the TensorFlow Model Optimization Toolkit.
alexander3c798932021-03-26 21:42:19 +000019
Alex Tawsedaba3cf2023-09-29 15:55:38 +010020The output from this example will be a TensorFlow Lite model file
21where weights in each layer have been 'clustered' into 16 clusters
22during training - quantization has then been applied on top of this.
alexander3c798932021-03-26 21:42:19 +000023
Alex Tawsedaba3cf2023-09-29 15:55:38 +010024By clustering the model we can improve compression of the model file.
25This can be essential for deploying certain models on systems with
26limited resources - such as embedded systems using an Arm Ethos NPU.
alexander3c798932021-03-26 21:42:19 +000027
Alex Tawsedaba3cf2023-09-29 15:55:38 +010028After performing clustering we do post-training quantization
29to quantize the model and then generate a TensorFlow Lite file.
alexander3c798932021-03-26 21:42:19 +000030
Alex Tawsedaba3cf2023-09-29 15:55:38 +010031If you are targeting an Arm Ethos-U55 NPU then the output
32TensorFlow Lite file will also need to be passed through the Vela
alexander3c798932021-03-26 21:42:19 +000033compiler for further optimizations before it can be used.
34
Alex Tawsedaba3cf2023-09-29 15:55:38 +010035For more information on using Vela see:
36 https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
37For more information on clustering see:
38 https://www.tensorflow.org/model_optimization/guide/clustering
alexander3c798932021-03-26 21:42:19 +000039"""
40import pathlib
41
42import tensorflow as tf
43import tensorflow_model_optimization as tfmot
44
45from training_utils import get_data, create_model
46from post_training_quantization import post_training_quantize, evaluate_tflite_model
47
48
def prepare_for_clustering(keras_model, *, number_of_clusters=16, cluster_centroids_init=None):
    """Prepares a Keras model for clustering.

    Wraps the whole model with the TensorFlow Model Optimization clustering
    wrapper, so the weights in every layer get clustered. The wrapped model is
    then recompiled so it can be fine-tuned with clustering applied.

    Args:
        keras_model: The Keras model to cluster. Clustering is easiest to apply
            as a fine-tuning step on an already trained model.
        number_of_clusters: Number of clusters (distinct shared weight values)
            per weight tensor. Defaults to 16.
        cluster_centroids_init: A ``tfmot.clustering.keras.CentroidInitialization``
            value controlling how the cluster centroids are initialized.
            Defaults to ``CentroidInitialization.LINEAR`` when ``None``.

    Returns:
        The clustering-wrapped, compiled Keras model.
    """
    # Choose the number of clusters to use and how to initialize them.
    # Using fewer clusters generally improves compression but reduces accuracy,
    # so you will need to find the optimal number for your use-case.
    if cluster_centroids_init is None:
        cluster_centroids_init = tfmot.clustering.keras.CentroidInitialization.LINEAR

    # Apply the clustering wrapper to the whole model so weights in
    # every layer will get clustered. You may find that to avoid
    # too much accuracy loss only certain non-critical layers in
    # your model should be clustered.
    clustering_ready_model = tfmot.clustering.keras.cluster_weights(
        keras_model,
        number_of_clusters=number_of_clusters,
        cluster_centroids_init=cluster_centroids_init
    )

    # We must recompile the model after making it ready for clustering.
    clustering_ready_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy']
    )

    return clustering_ready_model
76
77
def main():
    """
    Run weight clustering.

    Steps:
      1. Train a baseline model and report its test accuracy.
      2. Fine-tune it for one epoch with weight clustering applied and report
         the test accuracy again.
      3. Strip the clustering wrappers, apply post-training quantization, and
         write the TensorFlow Lite model to
         ``./conditioned_models/clustered_post_training_quant_model.tflite``.
      4. Evaluate the saved TensorFlow Lite model on a subset of the test data.
    """
    x_train, y_train, x_test, y_test = get_data()
    model = create_model()

    # Compile and train the model first.
    # In general, it is easier to do clustering as a
    # fine-tuning step after the model is fully trained.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy']
    )

    model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)

    # Test the trained model accuracy. evaluate() returns the loss followed by
    # the compiled metrics; only the accuracy is reported, so the loss is discarded.
    _, test_acc = model.evaluate(x_test, y_test)
    print(f"Test accuracy before clustering: {test_acc:.3f}")

    # Prepare the model for clustering.
    clustered_model = prepare_for_clustering(model)

    # Continue training the model but now with clustering applied.
    clustered_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True)
    _, test_acc = clustered_model.evaluate(x_test, y_test)
    print(f"Test accuracy after clustering: {test_acc:.3f}")

    # Remove all variables that clustering only needed in the training phase.
    model_for_export = tfmot.clustering.keras.strip_clustering(clustered_model)

    # Apply post-training quantization on top of the clustering
    # and save the resulting TensorFlow Lite model to file.
    tflite_model = post_training_quantize(model_for_export, x_train)

    tflite_models_dir = pathlib.Path('./conditioned_models/')
    tflite_models_dir.mkdir(exist_ok=True, parents=True)

    clustered_quant_model_save_path = \
        tflite_models_dir / 'clustered_post_training_quant_model.tflite'
    # write_bytes opens, writes and closes the file in one call.
    clustered_quant_model_save_path.write_bytes(tflite_model)

    # Test the clustered quantized model accuracy.
    # Save time by only testing a subset of the whole data.
    num_test_samples = 1000
    evaluate_tflite_model(
        clustered_quant_model_save_path,
        x_test[:num_test_samples],
        y_test[:num_test_samples]
    )
alexander3c798932021-03-26 21:42:19 +0000131
132
# Run the clustering example only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()