blob: 303b6df2f61da7a140e95a68774427188107b603 [file] [log] [blame]
Alex Tawsedaba3cf2023-09-29 15:55:38 +01001# SPDX-FileCopyrightText: Copyright 2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
alexander3c798932021-03-26 21:42:19 +00002# SPDX-License-Identifier: Apache-2.0
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""
This script will provide you with a short example of how to perform
magnitude-based weight pruning in TensorFlow
using the TensorFlow Model Optimization Toolkit.

The output from this example will be a TensorFlow Lite model file
where ~75% of the weights have been 'pruned' to the
value 0 during training - quantization has then been applied on top of this.

By pruning the model we can improve compression of the model file.
This can be essential for deploying certain models on systems
with limited resources - such as embedded systems using Arm Ethos NPU.
Also, if the pruned model is run on an Arm Ethos NPU then
this pruning can improve the execution time of the model.

After pruning is complete we do post-training quantization
to quantize the model and then generate a TensorFlow Lite file.

If you are targeting an Arm Ethos-U55 NPU then the output
TensorFlow Lite file will also need to be passed through the Vela
compiler for further optimizations before it can be used.

For more information on using Vela see:
    https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
For more information on weight pruning see:
    https://www.tensorflow.org/model_optimization/guide/pruning
alexander3c798932021-03-26 21:42:19 +000041"""
42import pathlib
43
44import tensorflow as tf
45import tensorflow_model_optimization as tfmot
46
47from training_utils import get_data, create_model
48from post_training_quantization import post_training_quantize, evaluate_tflite_model
49
50
def prepare_for_pruning(keras_model, *, target_sparsity=0.75, begin_step=0, learning_rate=0.001):
    """Wrap a Keras model for magnitude-based weight pruning and recompile it.

    Args:
        keras_model: The Keras model to make prunable. The pruning wrapper is
            applied to the whole model, so weights in every layer get pruned.
        target_sparsity: Sparsity level passed to the constant sparsity
            schedule. Defaults to 0.75, the value this example was written with.
        begin_step: Training step at which the sparsity schedule starts.
            Defaults to 0, i.e. pruning is active from the first step.
        learning_rate: Learning rate of the Adam optimizer used when the
            wrapped model is recompiled. Defaults to 0.001.

    Returns:
        The pruning-wrapped model, compiled with the Adam optimizer, sparse
        categorical cross-entropy loss and the 'accuracy' metric.
    """

    # We use a constant sparsity schedule so the amount of sparsity
    # in the model is kept at the same percent throughout training.
    # An alternative is PolynomialDecay where sparsity
    # can be gradually increased during training.
    pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=target_sparsity,
        begin_step=begin_step,
    )

    # Apply the pruning wrapper to the whole model
    # so weights in every layer will get pruned.
    # You may find that to avoid too much accuracy loss only
    # certain non-critical layers in your model should be pruned.
    pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(
        keras_model,
        pruning_schedule=pruning_schedule,
    )

    # We must recompile the model after making it ready for pruning.
    pruning_ready_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy'],
    )

    return pruning_ready_model
75
76
def main():
    """
    Train a model, fine-tune it with pruning, then quantize and save it.

    The trained model is evaluated before and after pruning; the pruned model
    is then stripped of its pruning wrappers, quantized post-training, written
    to a TensorFlow Lite file and evaluated on a subset of the test data.
    """
    x_train, y_train, x_test, y_test = get_data()
    model = create_model()

    # Arguments shared by the initial training run and the pruning run.
    shared_fit_args = {
        "x": x_train,
        "y": y_train,
        "batch_size": 128,
        "verbose": 1,
        "shuffle": True,
    }

    # Train the plain model first: in general it is easier to do pruning
    # as a fine-tuning step once the model is fully trained.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=['accuracy'],
    )
    model.fit(epochs=5, **shared_fit_args)

    # Baseline accuracy of the trained, un-pruned model (the loss is not used).
    _, test_acc = model.evaluate(x_test, y_test)
    print(f"Test accuracy before pruning: {test_acc:.3f}")

    # Wrap the model for pruning. The UpdatePruningStep callback is needed
    # during training, so remember to pass it to fit().
    pruned_model = prepare_for_pruning(model)
    pruning_callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

    # Continue training, this time with pruning applied.
    pruned_model.fit(epochs=1, callbacks=pruning_callbacks, **shared_fit_args)

    _, test_acc = pruned_model.evaluate(x_test, y_test)
    print(f"Test accuracy after pruning: {test_acc:.3f}")

    # Drop the variables that pruning only needed in the training phase.
    export_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

    # Quantize on top of the pruning, then save the TensorFlow Lite model.
    tflite_model = post_training_quantize(export_model, x_train)

    output_dir = pathlib.Path('./conditioned_models/')
    output_dir.mkdir(exist_ok=True, parents=True)

    save_path = output_dir / 'pruned_post_training_quant_model.tflite'
    with save_path.open('wb') as out_file:
        out_file.write(tflite_model)

    # Check the pruned, quantized model's accuracy.
    # Only a subset of the test data is used to save time.
    num_test_samples = 1000
    evaluate_tflite_model(
        save_path,
        x_test[:num_test_samples],
        y_test[:num_test_samples],
    )
alexander3c798932021-03-26 21:42:19 +0000136
137
# Run the pruning example only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()