CSAF-235 Arm NN Delegate Image Classification Example.

* To be used in the Image Classification guide on developer.arm.com.

Signed-off-by: Henri Woodcock henri.woodcock@arm.com
Change-Id: I3dd3b3b7ca3e579be9fd70900cff85c78f3da3f7
diff --git a/samples/ImageClassification/README.md b/samples/ImageClassification/README.md
new file mode 100644
index 0000000..068d0c9
--- /dev/null
+++ b/samples/ImageClassification/README.md
@@ -0,0 +1,135 @@
+# Image Classification with the Arm NN TensorFlow Lite Delegate
+
+This application demonstrates how to use the Arm NN TensorFlow Lite Delegate
+for image classification. It integrates the delegate into the TensorFlow Lite
+Python package.
+
+## Before You Begin
+
+This example assumes you have built `libarmnnDelegate.so` and `libarmnn.so`, or
+downloaded them from the GitHub releases page. You will also need to have built
+the TensorFlow Lite library from source.
+
+If you have not already done so, please follow our guides in the Arm NN
+repository. The guide to building the delegate can be found
+[here](../../delegate/BuildGuideNative.md), and the guide to integrating the
+delegate into Python can be found
+[here](../../delegate/IntegrateDelegateIntoPython.md).
+
+
+## Getting Started
+
+Before running the application, you will first need to:
+
+- Install the required system and Python packages
+- Download this example
+- Download a model and its corresponding label mapping
+- Download an example image
+- Copy the TensorFlow Lite and Arm NN libraries into the application folder
+
+1. Install the required system packages and Git Large File Storage (needed to
+download models from the Arm ML-Zoo).
+
+  ```bash
+  sudo apt-get install -y python3 python3-pip wget git git-lfs unzip
+  git lfs install
+  ```
+
+2. Clone the Arm NN repository and change directory to this example.
+
+  ```bash
+  git clone https://github.com/arm-software/armnn.git
+  cd armnn/samples/ImageClassification
+  ```
+
+3. Download the model and label mapping.
+
+  For this example we use the `MobileNetV2` model. The Arm ML-Zoo provides this
+  model, along with a script that generates its label mapping.
+
+  ```bash
+  export BASEDIR=$(pwd)
+  # Clone the model zoo
+  git clone https://github.com/arm-software/ml-zoo.git
+  # Go to the MobileNetV2 uint8 folder
+  cd ml-zoo/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8
+  # Generate the label mapping
+  ./get_class_labels.sh
+  # Return to this project folder
+  cd "$BASEDIR"
+  # Copy the model and label mapping
+  cp ml-zoo/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8/mobilenet_v2_1.0_224_quantized_1_default_1.tflite .
+  cp ml-zoo/models/image_classification/mobilenet_v2_1.0_224/tflite_uint8/labelmappings.txt .
+  ```
+
+4. Download a test image.
+
+  ```bash
+  wget -O cat.png "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
+  ```
+
+5. Install the required Python packages.
+
+  ```bash
+  pip3 install -r requirements.txt
+  ```
+
+6. Copy the `libtensorflow_lite_all.so` and `libarmnn.so` library files that you
+built or downloaded earlier into the application folder. For example:
+
+  ```bash
+  cp path/to/tensorflow/directory/tensorflow/bazel-bin/libtensorflow_lite_all.so .
+  cp /path/to/armnn/binaries/libarmnn.so .
+  ```
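+
+The label mapping generated in step 3 is read as plain text with one label per
+line, where a label's line number (counting from zero) is the class index the
+model outputs; this is the layout `run_classifier.py` expects. A minimal sketch
+of reading it under that assumption (`read_labels` is a hypothetical helper,
+not part of this sample):
+
+```python
+from pathlib import Path
+from typing import List
+
+
+def read_labels(label_path: Path) -> List[str]:
+    """Return labels in file order, so labels[i] is the label for class i.
+
+    Assumes one label per line, as run_classifier.py does.
+    """
+    with open(label_path) as label_file:
+        # strip() drops the trailing newline so printed labels stay clean
+        return [line.strip() for line in label_file]
+```
+
+For example, `read_labels(Path("labelmappings.txt"))[idx]` gives the label for
+output index `idx`.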
+
+## Folder Structure
+
+You should now have the following folder structure:
+
+```
+.
+├── README.md
+├── requirements.txt           # Python dependencies
+├── run_classifier.py          # script for the demo
+├── libtensorflow_lite_all.so  # TensorFlow Lite library built from TensorFlow
+├── libarmnn.so                # Arm NN library
+├── cat.png                    # downloaded example image
+├── mobilenet_v2_1.0_224_quantized_1_default_1.tflite  # TensorFlow Lite model from the ML-Zoo
+├── labelmappings.txt          # model label mapping for output processing
+└── ml-zoo/                    # cloned Arm ML-Zoo repository
+```
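+
+Before running the demo, you may want to confirm that the files it needs are in
+place. A minimal sketch, assuming the file names shown in the tree above
+(`missing_files` and `EXPECTED_FILES` are hypothetical, not part of this
+sample; adjust the names if you used a different model):
+
+```python
+from pathlib import Path
+from typing import Iterable, List
+
+# Hypothetical list of files the demo needs, taken from the folder structure.
+EXPECTED_FILES = (
+    "run_classifier.py",
+    "libtensorflow_lite_all.so",
+    "libarmnn.so",
+    "cat.png",
+    "mobilenet_v2_1.0_224_quantized_1_default_1.tflite",
+    "labelmappings.txt",
+)
+
+
+def missing_files(folder: Path, expected: Iterable[str] = EXPECTED_FILES) -> List[str]:
+    """Return the expected file names that are not regular files in folder."""
+    return [name for name in expected if not (folder / name).is_file()]
+
+
+if __name__ == "__main__":
+    missing = missing_files(Path("."))
+    print("All files present." if not missing else "Missing: " + ", ".join(missing))
+```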
+
+## Run the model
+
+```bash
+python3 run_classifier.py \
+--input_image cat.png \
+--model_file mobilenet_v2_1.0_224_quantized_1_default_1.tflite \
+--label_file labelmappings.txt \
+--delegate_path /path/to/delegate/libarmnnDelegate.so.24 \
+--preferred_backends GpuAcc CpuAcc CpuRef
+```
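+
+`--preferred_backends` takes one or more backend names in order of preference,
+and the script joins them into the comma-separated string passed as the
+delegate's `backends` option. A minimal sketch of that parsing, assuming
+argparse's `nargs="+"` (one or more values); rejecting unknown names with
+`choices` is an alternative to the explicit check in `run_classifier.py`, and
+`backends_option` is a hypothetical helper:
+
+```python
+import argparse
+from typing import List, Optional
+
+SUPPORTED_BACKENDS = ("GpuAcc", "CpuAcc", "CpuRef")
+
+
+def backends_option(argv: Optional[List[str]] = None) -> str:
+    """Parse --preferred_backends and return the delegate's "backends" string."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--preferred_backends",
+        nargs="+",                   # one or more values, e.g. GpuAcc CpuAcc CpuRef
+        choices=SUPPORTED_BACKENDS,  # each value must be a supported backend
+        default=["CpuAcc", "CpuRef"],
+    )
+    args = parser.parse_args(argv)
+    # Order is preserved, so the first backend listed is the most preferred.
+    return ",".join(args.preferred_backends)
+
+
+print(backends_option(["--preferred_backends", "GpuAcc", "CpuAcc", "CpuRef"]))
+# GpuAcc,CpuAcc,CpuRef
+```
+
+An unsupported name such as `NpuAcc` makes argparse print an error and exit.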
+
+The predicted label is printed. In this example the output is:
+
+```
+Prediction:  tabby, tabby cat
+```
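+
+The printed label comes from taking the index of the highest score in the
+output tensor and looking it up in the label mapping. To see runner-up classes
+as well, a minimal NumPy sketch, assuming the output has shape
+`(1, num_classes)` and `labels` is indexed by class (`top_k` is a hypothetical
+helper, not part of `run_classifier.py`):
+
+```python
+from typing import List, Sequence, Tuple
+
+import numpy as np
+
+
+def top_k(output_data: np.ndarray, labels: Sequence[str], k: int = 5) -> List[Tuple[str, float]]:
+    """Return the k highest-scoring (label, score) pairs, best first."""
+    scores = output_data[0]  # drop the batch dimension
+    # argsort is ascending, so reverse it and keep the first k indices
+    best = np.argsort(scores)[::-1][:k]
+    return [(labels[i], float(scores[i])) for i in best]
+```
+
+With `k=1` this returns the same label the script prints.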
+
+## Running an inference with the Arm NN TensorFlow Lite Delegate
+
+Compared to a typical TensorFlow Lite project, using the Arm NN TensorFlow Lite
+Delegate requires one extra step when loading your model: load the delegate and
+pass it to the interpreter.
+
+```python
+import tflite_runtime.interpreter as tflite
+
+armnn_delegate = tflite.load_delegate("/path/to/delegate/libarmnnDelegate.so",
+  options={
+    "backends": "GpuAcc,CpuAcc,CpuRef",
+    "logging-severity": "info"
+  }
+)
+interpreter = tflite.Interpreter(
+  model_path="mobilenet_v2_1.0_224_quantized_1_default_1.tflite",
+  experimental_delegates=[armnn_delegate]
+)
+```
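+
+After loading the model with the delegate, inference uses the usual
+TensorFlow Lite `Interpreter` calls: allocate tensors, set the input, invoke
+and read the output. A minimal sketch, assuming `interpreter` was created as
+above and `image` is already a NumPy array matching the model input, which for
+this quantized 224x224 MobileNetV2 should be shape `(1, 224, 224, 3)` with
+dtype `uint8` (`classify` is a hypothetical helper, not part of this sample):
+
+```python
+import numpy as np
+
+
+def classify(interpreter, image: np.ndarray, labels) -> str:
+    """Run one inference and return the label with the highest score.
+
+    interpreter: a tflite_runtime Interpreter, e.g. one using the Arm NN delegate.
+    labels: a list (or int-keyed dict) mapping class index to label.
+    """
+    interpreter.allocate_tensors()  # required before setting input tensors
+    input_details = interpreter.get_input_details()[0]
+    output_details = interpreter.get_output_details()[0]
+
+    # Fail early with a clear message if the image does not match the model input.
+    expected_shape = tuple(int(d) for d in input_details["shape"])
+    if image.shape != expected_shape or image.dtype != input_details["dtype"]:
+        raise ValueError(
+            f"expected {expected_shape} {input_details['dtype']}, "
+            f"got {image.shape} {image.dtype}"
+        )
+
+    interpreter.set_tensor(input_details["index"], image)
+    # Ops the delegate supports run through Arm NN; others fall back to TensorFlow Lite.
+    interpreter.invoke()
+    scores = interpreter.get_tensor(output_details["index"])[0]
+    return labels[int(np.argmax(scores))]
+```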
diff --git a/samples/ImageClassification/requirements.txt b/samples/ImageClassification/requirements.txt
new file mode 100644
index 0000000..3f100b2
--- /dev/null
+++ b/samples/ImageClassification/requirements.txt
@@ -0,0 +1,3 @@
+numpy==1.20.2
+Pillow==8.2.0
+pybind11==2.6.2
diff --git a/samples/ImageClassification/run_classifier.py b/samples/ImageClassification/run_classifier.py
new file mode 100644
index 0000000..1b4b9ed
--- /dev/null
+++ b/samples/ImageClassification/run_classifier.py
@@ -0,0 +1,237 @@
+import argparse
+from pathlib import Path
+from typing import Union
+
+import tflite_runtime.interpreter as tflite
+from PIL import Image
+import numpy as np
+
+
+def check_args(args: argparse.Namespace):
+    """Check that the command-line arguments have acceptable values.
+
+    args:
+      - args: argparse.Namespace
+
+    returns:
+      - None
+
+    raises:
+      - FileNotFoundError: if passed files do not exist.
+      - IOError: if files are of incorrect format.
+      - ValueError: if an unsupported backend is requested.
+    """
+    input_image_p = args.input_image
+    if input_image_p.suffix.lower() not in (".png", ".jpg", ".jpeg"):
+        raise IOError(
+            "--input_image option should point to an image file of the "
+            "format .jpg, .jpeg, .png"
+        )
+    if not input_image_p.exists():
+        raise FileNotFoundError(f"Cannot find {input_image_p}")
+    model_p = args.model_file
+    if model_p.suffix != ".tflite":
+        raise IOError("--model_file should point to a tflite file.")
+    if not model_p.exists():
+        raise FileNotFoundError(f"Cannot find {model_p}")
+    label_mapping_p = args.label_file
+    if label_mapping_p.suffix != ".txt":
+        raise IOError("--label_file expects a .txt file.")
+    if not label_mapping_p.exists():
+        raise FileNotFoundError(f"Cannot find {label_mapping_p}")
+
+    # Check that every requested backend is one this example supports
+    supported_backends = ["GpuAcc", "CpuAcc", "CpuRef"]
+    if not all(backend in supported_backends for backend in args.preferred_backends):
+        raise ValueError(
+            "Incorrect backends given. Please choose from "
+            "'GpuAcc', 'CpuAcc', 'CpuRef'."
+        )
+
+    return None
+
+
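+def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
+    """Load an image and put it into the correct format for the TensorFlow Lite model.
+
+    args:
+      - image_path: pathlib.Path
+      - model_input_dims: tuple (or array-like). (height, width)
+      - grayscale: bool -> convert the image to a single channel if True
+
+    returns:
+      - image: np.array of shape (1, height, width, channels)
+    """
+    height, width = model_input_dims
+    # Load the image and drop any alpha channel: single-channel ("L") for
+    # grayscale models, three-channel RGB otherwise
+    image = Image.open(image_path).convert("L" if grayscale else "RGB")
+    # Resize to the model input size (PIL expects (width, height))
+    image = np.asarray(image.resize((width, height)))
+    # Grayscale images have no channel axis; add one to match NHWC layout
+    if grayscale:
+        image = np.expand_dims(image, axis=-1)
+    # Add the batch dimension
+    image = np.expand_dims(image, axis=0)
+
+    return image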
+
+
+def load_delegate(delegate_path: Path, backends: list):
+    """Load the Arm NN delegate.
+
+    args:
+      - delegate_path: pathlib.Path -> location of your libarmnnDelegate.so
+      - backends: list -> list of backends you want to use in string format
+
+    returns:
+      - armnn_delegate: tflite.Delegate
+    """
+    # Create a comma-separated string of backends, in order of preference
+    backend_string = ",".join(backends)
+    # Load the delegate; the library path is passed as a string
+    armnn_delegate = tflite.load_delegate(
+        library=delegate_path.as_posix(),
+        options={"backends": backend_string, "logging-severity": "info"},
+    )
+
+    return armnn_delegate
+
+
+def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
+    """Load a TensorFlow Lite model for use with the Arm NN delegate.
+
+    args:
+      - model_path: pathlib.Path
+      - armnn_delegate: tflite.Delegate
+
+    returns:
+      - interpreter: tflite.Interpreter
+    """
+    interpreter = tflite.Interpreter(
+        model_path=model_path.as_posix(), experimental_delegates=[armnn_delegate]
+    )
+    interpreter.allocate_tensors()
+
+    return interpreter
+
+
+def run_inference(interpreter, input_image):
+    """Run inference on a processed input image and return the output from
+    inference.
+
+    args:
+      - interpreter: tflite_runtime.interpreter.Interpreter
+      - input_image: np.array
+
+    returns:
+      - output_data: np.array
+    """
+    # Get input and output tensors.
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+    # Copy the image into the input tensor, run inference and read the output.
+    interpreter.set_tensor(input_details[0]["index"], input_image)
+    interpreter.invoke()
+    output_data = interpreter.get_tensor(output_details[0]["index"])
+
+    return output_data
+
+
+def create_mapping(label_mapping_p):
+    """Creates a Python dictionary mapping an index to a label.
+
+    label_mapping[idx] = label
+
+    args:
+      - label_mapping_p: pathlib.Path
+
+    returns:
+      - label_mapping: dict
+    """
+    label_mapping = {}
+    with open(label_mapping_p) as label_mapping_raw:
+        # Each line holds one label; its line number is the class index
+        for idx, line in enumerate(label_mapping_raw):
+            label_mapping[idx] = line.strip()
+
+    return label_mapping
+
+
+def process_output(output_data, label_mapping):
+    """Process the output tensor into a label from the label mapping file. Takes
+    the index of the maximum value in the output array.
+
+    args:
+      - output_data: np.array
+      - label_mapping: dict
+
+    returns:
+      - str: label for the index with the highest score.
+    """
+    idx = np.argmax(output_data[0])
+
+    return label_mapping[idx]
+
+
+def main(args):
+    """Run the inference for options passed in the command line.
+
+    args:
+      - args: argparse.Namespace
+
+    returns:
+      - None
+    """
+    # sanity check on args
+    check_args(args)
+    # load in the armnn delegate
+    armnn_delegate = load_delegate(args.delegate_path, args.preferred_backends)
+    # load tflite model
+    interpreter = load_tf_model(args.model_file, armnn_delegate)
+    # get input shape for image resizing
+    input_shape = interpreter.get_input_details()[0]["shape"]
+    height, width = input_shape[1], input_shape[2]
+    input_shape = (height, width)
+    # load input image
+    input_image = load_image(args.input_image, input_shape, False)
+    # get label mapping
+    labelmapping = create_mapping(args.label_file)
+    output_tensor = run_inference(interpreter, input_image)
+    output_prediction = process_output(output_tensor, labelmapping)
+
+    print("Prediction: ", output_prediction)
+
+    return None
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--input_image", help="File path of image file", type=Path, required=True
+    )
+    parser.add_argument(
+        "--model_file",
+        help="File path of the model tflite file",
+        type=Path,
+        required=True,
+    )
+    parser.add_argument(
+        "--label_file",
+        help="File path of model labelmapping file",
+        type=Path,
+        required=True,
+    )
+    parser.add_argument(
+        "--delegate_path",
+        help="File path of the Arm NN delegate library (libarmnnDelegate.so)",
+        type=Path,
+        required=True,
+    )
+    parser.add_argument(
+        "--preferred_backends",
+        help="List of backends in order of preference",
+        type=str,
+        nargs="+",
+        required=False,
+        default=["CpuAcc", "CpuRef"],
+    )
+    args = parser.parse_args()
+
+    main(args)