Opensource ML embedded evaluation kit

Change-Id: I12e807f19f5cacad7cef82572b6dd48252fd61fd
diff --git a/model_conditioning_examples/Readme.md b/model_conditioning_examples/Readme.md
new file mode 100644
index 0000000..ede2c24
--- /dev/null
+++ b/model_conditioning_examples/Readme.md
@@ -0,0 +1,173 @@
+# Model conditioning examples
+
+- [Introduction](#introduction)
+  - [How to run](#how-to-run)
+- [Quantization](#quantization)
+  - [Post-training quantization](#post-training-quantization)
+  - [Quantization aware training](#quantization-aware-training)
+- [Weight pruning](#weight-pruning)
+- [Weight clustering](#weight-clustering)
+- [References](#references)
+
+## Introduction
+
+This folder contains short example scripts that demonstrate some methods available in TensorFlow to condition your model
+in preparation for deployment on Arm Ethos NPU.
+
+These scripts will cover three main topics:
+
+- Quantization
+- Weight clustering
+- Weight pruning
+
+The objective of these scripts is not to be a single source of knowledge on everything related to model conditioning.
+Instead, the aim is to provide the reader with a quick starting point that demonstrates some commonly used tools for
+enabling models to run on Arm Ethos NPU and optimizing them to get maximum performance from the Arm Ethos NPU.
+
+Links to more in-depth guides available on the TensorFlow website are provided in the [references](#references) section
+in this Readme.
+
+### How to run
+
+From the `model_conditioning_examples` folder run the following command:
+
+```commandline
+./setup.sh
+```
+
+This will create a Python virtual environment and install the required versions of TensorFlow and the TensorFlow Model
+Optimization Toolkit needed to run the example scripts.
+
+If the virtual environment has not been activated you can do so by running:
+
+```commandline
+source ./env/bin/activate
+```
+
+You can then run the examples from the command line. For example, to run the post-training quantization example:
+
+```commandline
+python ./post_training_quantization.py
+```
+
+The produced TensorFlow Lite model files will be saved in a `conditioned_models` sub-folder.
+
+## Quantization
+
+Most machine learning models are trained using 32-bit floating point precision. However, Arm Ethos NPU performs
+calculations in 8-bit integer precision. As a result, any model you wish to deploy on Arm Ethos NPU must first be
+fully quantized to 8 bits.
+
+TensorFlow provides two methods of quantization and the scripts in this folder will demonstrate these:
+
+- [Post-training quantization](./post_training_quantization.py)
+- [Quantization aware training](./quantization_aware_training.py)
+
+Both of these techniques will quantize not only the weights of the model but also the variable tensors, such as model
+input and output, and the activations of each intermediate layer.
+
+For details on the quantization specification used by TensorFlow, please see the
+[TensorFlow Lite quantization spec](https://www.tensorflow.org/lite/performance/quantization_spec).
+
+In both methods, scale and zero point values are chosen to allow the floating point weights to be represented as
+accurately as possible in this reduced precision. Quantization is performed per-axis, meaning a different scale and
+zero point is used for each channel of a layer's weights.
+
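+To make the affine mapping concrete, below is a minimal numpy sketch of how a float tensor is quantized to 8-bit
+integers and dequantized back. The scale, zero point and weight values here are made up purely for illustration:
+
+```python
+import numpy as np
+
+# Illustrative per-channel quantization parameters for a layer with two channels.
+scales = np.array([0.02, 0.5], dtype=np.float32)
+zero_points = np.array([0, 10], dtype=np.int32)
+
+# Quantize: q = round(r / scale) + zero_point, clamped to the int8 range.
+weights = np.array([[0.5, -0.3], [1.2, 4.9]], dtype=np.float32)
+quantized = np.clip(np.round(weights / scales) + zero_points, -128, 127).astype(np.int8)
+
+# Dequantize: r = scale * (q - zero_point). Note the small error introduced by rounding.
+dequantized = scales * (quantized.astype(np.int32) - zero_points)
+```
+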
+### Post-training quantization
+
+The first of the quantization methods that will be covered is called post-training quantization. As the name suggests,
+this form of quantization takes place after training of your model is complete. It is also the simpler of the two
+methods we will show for quantizing a model.
+
+In post-training quantization, the weights of the model are first quantized to 8-bit integer values. After this, we
+quantize the variable tensors, such as layer activations. To do this we need to calculate the potential range of values
+that all these tensors can take.
+
+Calculating these ranges requires a small dataset that is representative of what you expect your model to see when
+it is deployed. Model inference is then performed using this representative dataset and the resulting minimum and
+maximum values for variable tensors are calculated.
+
+Only a small number of samples need to be used in this calibration dataset (around 100-500 should be enough). These
+samples can be taken from the training or validation sets.
+
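+A minimal sketch of this calibration step is shown below, assuming `model` is your trained Keras model and `samples`
+is a numpy array drawn from your training data. The complete version is in
+[post_training_quantization.py](./post_training_quantization.py):
+
+```python
+import numpy as np
+import tensorflow as tf
+
+def representative_dataset():
+    # Yield one sample at a time, each with a leading batch dimension.
+    for sample in samples[:100]:
+        yield [np.expand_dims(sample, axis=0)]
+
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.representative_dataset = representative_dataset
+tflite_model = converter.convert()
+```
+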
+Quantizing your model can result in an accuracy drop, the size of which depends on your model. However, for a lot of
+use cases the accuracy drop when using post-training quantization is usually minimal. After post-training quantization
+is complete you will have a fully quantized TensorFlow Lite model.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+compiler for further optimizations before it can be used.
+
+### Quantization aware training
+
+Depending on the model, the use of post-training quantization can result in an accuracy drop that is too large to be
+considered acceptable. This is where quantization aware training can be used to recover accuracy. Quantization aware
+training simulates the quantization of weights and activations during the inference stage of training using fake
+quantization nodes.
+
+By simulating quantization during training, the model weights are adjusted in the backward pass so that they are
+better suited to the reduced precision of quantization. It is this simulation of quantization, and the corresponding
+adjustment of the weights, that can minimize the accuracy loss incurred when quantizing. Note that quantization is
+only simulated at this stage and backward passes of training are still performed in full floating point precision.
+
+Importantly, with quantization aware training you do not have to train your model from scratch to use it. Instead, you
+can train it normally (not quantization aware) and after training is complete you can then fine-tune it using
+quantization aware training. By only fine-tuning you can save a lot of training time.
+
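+A sketch of this fine-tuning flow with the TensorFlow Model Optimization Toolkit is shown below, assuming `model`,
+`x_train` and `y_train` already exist. The full example is in
+[quantization_aware_training.py](./quantization_aware_training.py):
+
+```python
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+
+# Wrap an already trained Keras model with fake quantization nodes.
+quant_aware_model = tfmot.quantization.keras.quantize_model(model)
+
+# Recompile and fine-tune for a small number of epochs.
+quant_aware_model.compile(optimizer='adam',
+                          loss=tf.keras.losses.sparse_categorical_crossentropy,
+                          metrics=['accuracy'])
+quant_aware_model.fit(x_train, y_train, epochs=1)
+```
+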
+As well as simulating quantization and adjusting weights, the ranges for variable tensors are captured so that the
+model can be fully quantized afterwards. Once you have finished quantization aware training the TensorFlow Lite converter is
+used to produce a fully quantized TensorFlow Lite model.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+compiler for further optimizations before it can be used.
+
+## Weight pruning
+
+After you have trained your deep learning model it is common to find that many of the weights in the model have a
+value of 0 or very close to 0. These weights have very little effect on network calculations, so they can safely be
+removed, or 'pruned', from the model. This is accomplished by setting all these weight values to 0, resulting in a
+sparse model.
+
+Compression algorithms can then take advantage of this to reduce model size in memory, which can be very important when
+deploying on small embedded systems. Moreover, Arm Ethos NPU can take advantage of model sparsity to further accelerate
+execution of a model.
+
+Training with weight pruning forces your model to have a certain percentage of its weights set (or 'pruned') to 0
+during the training phase. This is done by forcing the weights closest to 0 to become exactly 0. Doing this during
+training guarantees your model will have a certain level of sparsity, and the weights of the model can adapt better
+to the chosen sparsity level. This means accuracy loss can hopefully be minimized even if a large pruning percentage
+is desired.
+
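+A sketch of pruning as a fine-tuning step is shown below, using a constant 75% sparsity and assuming `model`,
+`x_train` and `y_train` already exist. The full example is in [weight_pruning.py](./weight_pruning.py):
+
+```python
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+
+# Keep sparsity at a constant 75% throughout fine-tuning.
+schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.75, begin_step=0)
+pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=schedule)
+
+pruned_model.compile(optimizer='adam',
+                     loss=tf.keras.losses.sparse_categorical_crossentropy,
+                     metrics=['accuracy'])
+
+# The pruning step callback must be registered during training.
+pruned_model.fit(x_train, y_train, epochs=1,
+                 callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
+
+# Strip the wrappers that were only needed during training before export.
+final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
+```
+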
+Weight pruning can be further combined with quantization so you have a model that is both pruned and quantized, meaning
+that the memory saving effects of both can be combined. Quantization then allows the model to be used with
+Arm Ethos NPU.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+compiler for further optimizations before it can be used.
+
+## Weight clustering
+
+Another method of model conditioning is weight clustering (also called weight sharing). With this technique, a fixed
+number of values (cluster centers) are used in each layer of a model to represent all the possible values that the
+layer's weights take. The weights in a layer will then use the value of their closest cluster center. By restricting
+the number of possible clusters, weight clustering reduces the amount of memory needed to store all the weight values
+in a model.
+
+Depending on the model and number of clusters chosen, using this kind of technique can have a negative effect on
+accuracy. To reduce the impact on accuracy you can introduce clustering during training so the model's weights can be
+better adjusted to the reduced precision.
+
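+A sketch of clustering as a fine-tuning step is shown below, using 16 clusters and assuming `model`, `x_train` and
+`y_train` already exist. The full example is in [weight_clustering.py](./weight_clustering.py):
+
+```python
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+
+# Cluster the weights of every layer into 16 clusters, initialized linearly over the weight range.
+clustered_model = tfmot.clustering.keras.cluster_weights(
+    model,
+    number_of_clusters=16,
+    cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.LINEAR)
+
+clustered_model.compile(optimizer='adam',
+                        loss=tf.keras.losses.sparse_categorical_crossentropy,
+                        metrics=['accuracy'])
+clustered_model.fit(x_train, y_train, epochs=1)
+
+# Strip the clustering wrappers before converting to TensorFlow Lite.
+final_model = tfmot.clustering.keras.strip_clustering(clustered_model)
+```
+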
+Weight clustering can be further combined with quantization so you have a model that is both clustered and quantized,
+meaning that the memory saving effects of both can be combined. Quantization then allows the model to be used with
+Arm Ethos NPU.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+compiler for further optimizations before it can be used (see [Optimize model with Vela compiler](./building.md#optimize-custom-model-with-vela-compiler)).
+
+## References
+
+- [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization)
+- [Post-training quantization](https://www.tensorflow.org/lite/performance/post_training_integer_quant)
+- [Quantization aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training)
+- [Weight pruning](https://www.tensorflow.org/model_optimization/guide/pruning)
+- [Weight clustering](https://www.tensorflow.org/model_optimization/guide/clustering)
+- [Vela](https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/)
diff --git a/model_conditioning_examples/post_training_quantization.py b/model_conditioning_examples/post_training_quantization.py
new file mode 100644
index 0000000..ab535ac
--- /dev/null
+++ b/model_conditioning_examples/post_training_quantization.py
@@ -0,0 +1,139 @@
+#  Copyright (c) 2021 Arm Limited. All rights reserved.
+#  SPDX-License-Identifier: Apache-2.0
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+"""
+This script will provide you with an example of how to perform post-training quantization in TensorFlow.
+
+The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8-bit
+integer values.
+
+Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
+Ethos NPU.
+
+In addition to quantizing weights, post-training quantization uses a calibration dataset to
+capture the minimum and maximum values of all variable tensors in your model.
+By capturing these ranges it is possible to fully quantize not just the weights of the model but also the activations.
+
+Depending on the model you are quantizing there may be some accuracy loss, but for a lot of models the loss should
+be minimal.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+compiler for further optimizations before it can be used.
+
+For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on post-training quantization
+see: https://www.tensorflow.org/lite/performance/post_training_integer_quant
+"""
+import pathlib
+
+import numpy as np
+import tensorflow as tf
+
+from training_utils import get_data, create_model
+
+
+def post_training_quantize(keras_model, sample_data):
+    """Quantize Keras model using post-training quantization with some sample data.
+
+    The TensorFlow Lite model will have fp32 inputs/outputs and will handle quantizing/dequantizing internally.
+
+    Args:
+        keras_model: Keras model to quantize.
+        sample_data: A numpy array of data to use as a representative dataset.
+
+    Returns:
+        Quantized TensorFlow Lite model.
+    """
+
+    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+    # We set the following converter options to ensure our model is fully quantized.
+    # An error will be thrown if there are any ops that cannot be quantized.
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+
+    # To use post training quantization we must provide some sample data that will be used to
+    # calculate activation ranges for quantization. This data should be representative of the data
+    # we expect to feed the model and must be provided by a generator function.
+    def generate_repr_dataset():
+        for i in range(100):  # 100 samples is all we should need in this example.
+            yield [np.expand_dims(sample_data[i], axis=0)]
+
+    converter.representative_dataset = generate_repr_dataset
+    tflite_model = converter.convert()
+
+    return tflite_model
+
+
+def evaluate_tflite_model(tflite_save_path, x_test, y_test):
+    """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
+
+    Args:
+        tflite_save_path: Path to TensorFlow Lite model to test.
+        x_test: numpy array of testing data.
+        y_test: numpy array of testing labels (sparse categorical).
+    """
+
+    interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path))
+
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    accuracy_count = 0
+    num_test_images = len(y_test)
+
+    for i in range(num_test_images):
+        interpreter.set_tensor(input_details[0]['index'], x_test[i][np.newaxis, ...])
+        interpreter.invoke()
+        output_data = interpreter.get_tensor(output_details[0]['index'])
+
+        if np.argmax(output_data) == y_test[i]:
+            accuracy_count += 1
+
+    print(f"Test accuracy quantized: {accuracy_count / num_test_images:.3f}")
+
+
+def main():
+    x_train, y_train, x_test, y_test = get_data()
+    model = create_model()
+
+    # Compile and train the model in fp32 as normal.
+    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+                  loss=tf.keras.losses.sparse_categorical_crossentropy,
+                  metrics=['accuracy'])
+
+    model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
+
+    # Test the fp32 model accuracy.
+    test_loss, test_acc = model.evaluate(x_test, y_test)
+    print(f"Test accuracy float: {test_acc:.3f}")
+
+    # Quantize and export the resulting TensorFlow Lite model to file.
+    tflite_model = post_training_quantize(model, x_train)
+
+    tflite_models_dir = pathlib.Path('./conditioned_models/')
+    tflite_models_dir.mkdir(exist_ok=True, parents=True)
+
+    quant_model_save_path = tflite_models_dir / 'post_training_quant_model.tflite'
+    with open(quant_model_save_path, 'wb') as f:
+        f.write(tflite_model)
+
+    # Test the quantized model accuracy. Save time by only testing a subset of the whole data.
+    num_test_samples = 1000
+    evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/model_conditioning_examples/quantization_aware_training.py b/model_conditioning_examples/quantization_aware_training.py
new file mode 100644
index 0000000..acb768c
--- /dev/null
+++ b/model_conditioning_examples/quantization_aware_training.py
@@ -0,0 +1,139 @@
+#  Copyright (c) 2021 Arm Limited. All rights reserved.
+#  SPDX-License-Identifier: Apache-2.0
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+"""
+This script will provide you with a short example of how to perform quantization aware training in TensorFlow using the
+TensorFlow Model Optimization Toolkit.
+
+The output from this example will be a TensorFlow Lite model file where weights and activations are quantized to 8-bit
+integer values.
+
+Quantization helps reduce the size of your models and is necessary for running models on certain hardware such as Arm
+Ethos NPU.
+
+In quantization aware training (QAT), the error introduced with quantizing from fp32 to int8 is simulated using
+fake quantization nodes. By simulating this quantization error when training, the model can learn better adapted
+weights and minimize accuracy losses caused by the reduced precision.
+
+Minimum and maximum values for activations are also captured during training so activations for every layer can be
+quantized along with the weights later.
+
+Quantization is only simulated during training and the training backward passes are still performed in full float
+precision. Actual quantization happens when generating a TensorFlow Lite model.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+compiler for further optimizations before it can be used.
+
+For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on quantization aware training
+see: https://www.tensorflow.org/model_optimization/guide/quantization/training
+"""
+import pathlib
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+
+from training_utils import get_data, create_model
+
+
+def quantize_and_convert_to_tflite(keras_model):
+    """Quantize and convert Keras model trained with QAT to TensorFlow Lite.
+
+    The TensorFlow Lite model will have fp32 inputs/outputs and will handle quantizing/dequantizing internally.
+
+    Args:
+        keras_model: Keras model trained with quantization aware training.
+
+    Returns:
+        Quantized TensorFlow Lite model.
+    """
+
+    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+    # After doing quantization aware training all the information for creating a fully quantized
+    # TensorFlow Lite model is already within the quantization aware Keras model.
+    # This means we only need to call convert with default optimizations to generate the quantized TensorFlow Lite model.
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    tflite_model = converter.convert()
+
+    return tflite_model
+
+
+def evaluate_tflite_model(tflite_save_path, x_test, y_test):
+    """Calculate the accuracy of a TensorFlow Lite model using TensorFlow Lite interpreter.
+
+    Args:
+        tflite_save_path: Path to TensorFlow Lite model to test.
+        x_test: numpy array of testing data.
+        y_test: numpy array of testing labels (sparse categorical).
+    """
+
+    interpreter = tf.lite.Interpreter(model_path=str(tflite_save_path))
+
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    accuracy_count = 0
+    num_test_images = len(y_test)
+
+    for i in range(num_test_images):
+        interpreter.set_tensor(input_details[0]['index'], x_test[i][np.newaxis, ...])
+        interpreter.invoke()
+        output_data = interpreter.get_tensor(output_details[0]['index'])
+
+        if np.argmax(output_data) == y_test[i]:
+            accuracy_count += 1
+
+    print(f"Test accuracy quantized: {accuracy_count / num_test_images:.3f}")
+
+
+def main():
+    x_train, y_train, x_test, y_test = get_data()
+    model = create_model()
+
+    # When working with the TensorFlow Keras API and the TF Model Optimization Toolkit we can make our
+    # model quantization aware in one line. Once this is done we compile the model and train as normal.
+    # It is important to note that the model is only quantization aware and is not quantized yet. The weights are
+    # still floating point and will only be converted to int8 when we generate the TensorFlow Lite model later on.
+    quant_aware_model = tfmot.quantization.keras.quantize_model(model)
+
+    quant_aware_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+                              loss=tf.keras.losses.sparse_categorical_crossentropy,
+                              metrics=['accuracy'])
+
+    quant_aware_model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
+
+    # Test the quantization aware model accuracy.
+    test_loss, test_acc = quant_aware_model.evaluate(x_test, y_test)
+    print(f"Test accuracy quant aware: {test_acc:.3f}")
+
+    # Quantize and save the resulting TensorFlow Lite model to file.
+    tflite_model = quantize_and_convert_to_tflite(quant_aware_model)
+
+    tflite_models_dir = pathlib.Path('./conditioned_models/')
+    tflite_models_dir.mkdir(exist_ok=True, parents=True)
+
+    quant_model_save_path = tflite_models_dir / 'qat_quant_model.tflite'
+    with open(quant_model_save_path, 'wb') as f:
+        f.write(tflite_model)
+
+    # Test quantized model accuracy. Save time by only testing a subset of the whole data.
+    num_test_samples = 1000
+    evaluate_tflite_model(quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/model_conditioning_examples/requirements.txt b/model_conditioning_examples/requirements.txt
new file mode 100644
index 0000000..96e15a3
--- /dev/null
+++ b/model_conditioning_examples/requirements.txt
@@ -0,0 +1,3 @@
+tensorflow==2.4.0
+tensorflow-model-optimization==0.5.0
+numpy==1.19.5
\ No newline at end of file
diff --git a/model_conditioning_examples/setup.sh b/model_conditioning_examples/setup.sh
new file mode 100644
index 0000000..f552662
--- /dev/null
+++ b/model_conditioning_examples/setup.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+#----------------------------------------------------------------------------
+#  Copyright (c) 2021 Arm Limited. All rights reserved.
+#  SPDX-License-Identifier: Apache-2.0
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#----------------------------------------------------------------------------
+python3 -m venv ./env
+source ./env/bin/activate
+pip install -U pip
+pip install -r requirements.txt
\ No newline at end of file
diff --git a/model_conditioning_examples/training_utils.py b/model_conditioning_examples/training_utils.py
new file mode 100644
index 0000000..3467b2a
--- /dev/null
+++ b/model_conditioning_examples/training_utils.py
@@ -0,0 +1,61 @@
+#  Copyright (c) 2021 Arm Limited. All rights reserved.
+#  SPDX-License-Identifier: Apache-2.0
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+"""
+Utility functions related to data and models that are common to all the model conditioning examples.
+"""
+import tensorflow as tf
+import numpy as np
+
+
+def get_data():
+    """Downloads and returns the pre-processed data and labels for training and testing.
+
+    Returns:
+        Tuple of: (train data, train labels, test data, test labels)
+    """
+
+    # To save time we use the MNIST dataset for this example.
+    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+
+    # Convolution operations require data to have 4 dimensions.
+    # We divide by 255 to help training and cast to float32 for TensorFlow.
+    x_train = (x_train[..., np.newaxis] / 255.0).astype(np.float32)
+    x_test = (x_test[..., np.newaxis] / 255.0).astype(np.float32)
+
+    return x_train, y_train, x_test, y_test
+
+
+def create_model():
+    """Create and returns a simple Keras model for training MNIST.
+
+    We will use a simple convolutional neural network for this example,
+    but the model optimization methods employed should be compatible with a
+    wide variety of CNN architectures such as MobileNet and Inception.
+
+    Returns:
+        Uncompiled Keras model.
+    """
+
+    keras_model = tf.keras.models.Sequential([
+        tf.keras.layers.Conv2D(32, 3, padding='same', input_shape=(28, 28, 1), activation=tf.nn.relu),
+        tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
+        tf.keras.layers.MaxPool2D(),
+        tf.keras.layers.Conv2D(32, 3, padding='same', activation=tf.nn.relu),
+        tf.keras.layers.MaxPool2D(),
+        tf.keras.layers.Flatten(),
+        tf.keras.layers.Dense(units=10, activation=tf.nn.softmax)
+    ])
+
+    return keras_model
diff --git a/model_conditioning_examples/weight_clustering.py b/model_conditioning_examples/weight_clustering.py
new file mode 100644
index 0000000..54f241c
--- /dev/null
+++ b/model_conditioning_examples/weight_clustering.py
@@ -0,0 +1,107 @@
+#  Copyright (c) 2021 Arm Limited. All rights reserved.
+#  SPDX-License-Identifier: Apache-2.0
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+"""
+This script will provide you with a short example of how to perform clustering of weights (weight sharing) in
+TensorFlow using the TensorFlow Model Optimization Toolkit.
+
+The output from this example will be a TensorFlow Lite model file where weights in each layer have been 'clustered' into
+16 clusters during training - quantization has then been applied on top of this.
+
+By clustering the model we can improve compression of the model file. This can be essential for deploying certain
+models on systems with limited resources - such as embedded systems using an Arm Ethos NPU.
+
+After performing clustering we do post-training quantization to quantize the model and then generate a TensorFlow Lite file.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+compiler for further optimizations before it can be used.
+
+For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on clustering see: https://www.tensorflow.org/model_optimization/guide/clustering
+"""
+import pathlib
+
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+
+from training_utils import get_data, create_model
+from post_training_quantization import post_training_quantize, evaluate_tflite_model
+
+
+def prepare_for_clustering(keras_model):
+    """Prepares a Keras model for clustering."""
+
+    # Choose the number of clusters to use and how to initialize them. Using more clusters will generally
+    # reduce accuracy so you will need to find the optimal number for your use-case.
+    number_of_clusters = 16
+    cluster_centroids_init = tfmot.clustering.keras.CentroidInitialization.LINEAR
+
+    # Apply the clustering wrapper to the whole model so weights in every layer will get clustered. You may find that
+    # to avoid too much accuracy loss only certain non-critical layers in your model should be clustered.
+    clustering_ready_model = tfmot.clustering.keras.cluster_weights(keras_model,
+                                                                    number_of_clusters=number_of_clusters,
+                                                                    cluster_centroids_init=cluster_centroids_init)
+
+    # We must recompile the model after making it ready for clustering.
+    clustering_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+                                   loss=tf.keras.losses.sparse_categorical_crossentropy,
+                                   metrics=['accuracy'])
+
+    return clustering_ready_model
+
+
+def main():
+    x_train, y_train, x_test, y_test = get_data()
+    model = create_model()
+
+    # Compile and train the model first.
+    # In general it is easier to do clustering as a fine-tuning step after the model is fully trained.
+    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+                  loss=tf.keras.losses.sparse_categorical_crossentropy,
+                  metrics=['accuracy'])
+
+    model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
+
+    # Test the trained model accuracy.
+    test_loss, test_acc = model.evaluate(x_test, y_test)
+    print(f"Test accuracy before clustering: {test_acc:.3f}")
+
+    # Prepare the model for clustering.
+    clustered_model = prepare_for_clustering(model)
+
+    # Continue training the model but now with clustering applied.
+    clustered_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True)
+    test_loss, test_acc = clustered_model.evaluate(x_test, y_test)
+    print(f"Test accuracy after clustering: {test_acc:.3f}")
+
+    # Remove all variables that clustering only needed in the training phase.
+    model_for_export = tfmot.clustering.keras.strip_clustering(clustered_model)
+
+    # Apply post-training quantization on top of the clustering and save the resulting TensorFlow Lite model to file.
+    tflite_model = post_training_quantize(model_for_export, x_train)
+
+    tflite_models_dir = pathlib.Path('./conditioned_models/')
+    tflite_models_dir.mkdir(exist_ok=True, parents=True)
+
+    clustered_quant_model_save_path = tflite_models_dir / 'clustered_post_training_quant_model.tflite'
+    with open(clustered_quant_model_save_path, 'wb') as f:
+        f.write(tflite_model)
+
+    # Test the clustered quantized model accuracy. Save time by only testing a subset of the whole data.
+    num_test_samples = 1000
+    evaluate_tflite_model(clustered_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/model_conditioning_examples/weight_pruning.py b/model_conditioning_examples/weight_pruning.py
new file mode 100644
index 0000000..bf26f1f
--- /dev/null
+++ b/model_conditioning_examples/weight_pruning.py
@@ -0,0 +1,106 @@
+#  Copyright (c) 2021 Arm Limited. All rights reserved.
+#  SPDX-License-Identifier: Apache-2.0
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+"""
+This script will provide you with a short example of how to perform magnitude-based weight pruning in TensorFlow
+using the TensorFlow Model Optimization Toolkit.
+
+The output from this example will be a TensorFlow Lite model file where ~75% of the weights have been 'pruned' to the
+value 0 during training - quantization has then been applied on top of this.
+
+By pruning the model we can improve compression of the model file. This can be essential for deploying certain models
+on systems with limited resources - such as embedded systems using Arm Ethos NPU. Also, if the pruned model is run
+on an Arm Ethos NPU then this pruning can improve the execution time of the model.
+
+After pruning is complete we do post-training quantization to quantize the model and then generate a TensorFlow Lite file.
+
+If you are targeting an Arm Ethos-U55 NPU then the output TensorFlow Lite file will also need to be passed through the Vela
+compiler for further optimizations before it can be used.
+
+For more information on using Vela see: https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git/about/
+For more information on weight pruning see: https://www.tensorflow.org/model_optimization/guide/pruning
+"""
+import pathlib
+
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+
+from training_utils import get_data, create_model
+from post_training_quantization import post_training_quantize, evaluate_tflite_model
+
+
+def prepare_for_pruning(keras_model):
+    """Prepares a Keras model for pruning."""
+
+    # We use a constant sparsity schedule so the amount of sparsity in the model is kept at the same percent throughout
+    # training. An alternative is PolynomialDecay where sparsity can be gradually increased during training.
+    pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.75, begin_step=0)
+
+    # Apply the pruning wrapper to the whole model so weights in every layer will get pruned. You may find that to avoid
+    # too much accuracy loss only certain non-critical layers in your model should be pruned.
+    pruning_ready_model = tfmot.sparsity.keras.prune_low_magnitude(keras_model, pruning_schedule=pruning_schedule)
+
+    # We must recompile the model after making it ready for pruning.
+    pruning_ready_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+                                loss=tf.keras.losses.sparse_categorical_crossentropy,
+                                metrics=['accuracy'])
+
+    return pruning_ready_model
+
+
+def main():
+    x_train, y_train, x_test, y_test = get_data()
+    model = create_model()
+
+    # Compile and train the model first.
+    # In general it is easier to do pruning as a fine-tuning step after the model is fully trained.
+    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
+                  loss=tf.keras.losses.sparse_categorical_crossentropy,
+                  metrics=['accuracy'])
+
+    model.fit(x=x_train, y=y_train, batch_size=128, epochs=5, verbose=1, shuffle=True)
+
+    # Test the trained model accuracy.
+    test_loss, test_acc = model.evaluate(x_test, y_test)
+    print(f"Test accuracy before pruning: {test_acc:.3f}")
+
+    # Prepare the model for pruning and add the pruning update callback needed in training.
+    pruned_model = prepare_for_pruning(model)
+    callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
+
+    # Continue training the model but now with pruning applied - remember to pass in the callbacks!
+    pruned_model.fit(x=x_train, y=y_train, batch_size=128, epochs=1, verbose=1, shuffle=True, callbacks=callbacks)
+    test_loss, test_acc = pruned_model.evaluate(x_test, y_test)
+    print(f"Test accuracy after pruning: {test_acc:.3f}")
+
+    # Remove all variables that pruning only needed in the training phase.
+    model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model)
+
+    # Apply post-training quantization on top of the pruning and save the resulting TensorFlow Lite model to file.
+    tflite_model = post_training_quantize(model_for_export, x_train)
+
+    tflite_models_dir = pathlib.Path('./conditioned_models/')
+    tflite_models_dir.mkdir(exist_ok=True, parents=True)
+
+    pruned_quant_model_save_path = tflite_models_dir / 'pruned_post_training_quant_model.tflite'
+    with open(pruned_quant_model_save_path, 'wb') as f:
+        f.write(tflite_model)
+
+    # Test the pruned quantized model accuracy. Save time by only testing a subset of the whole data.
+    num_test_samples = 1000
+    evaluate_tflite_model(pruned_quant_model_save_path, x_test[0:num_test_samples], y_test[0:num_test_samples])
+
+
+if __name__ == "__main__":
+    main()