PyArmNN example scripts

Change-Id: I2a5c3d291d19982c536c6b7341c01bb7c289871a
Signed-off-by: Pavel Macenauer <pavel.macenauer@nxp.com>
diff --git a/python/pyarmnn/README.md b/python/pyarmnn/README.md
index 7028ff8..a8f7573 100644
--- a/python/pyarmnn/README.md
+++ b/python/pyarmnn/README.md
@@ -122,6 +122,18 @@
 print(results)
 ```
 
+#### Examples
+
+To further explore PyArmNN API there are several examples provided in the examples folder running classification on an image. To run them first install the dependencies:
+ ```bash
+$ pip install -r examples/requirements.txt
+```
+Afterwards simply execute the example scripts, e.g.:
+ ```bash
+$ python tflite_mobilenetv1_quantized.py
+```
+All resources are downloaded during execution, so if you do not have access to the internet, you may need to download these manually. `example_utils.py` contains code shared between the examples. 
+
 # Setup development environment
 
 Before, proceeding to the next steps, make sure that:
diff --git a/python/pyarmnn/examples/example_utils.py b/python/pyarmnn/examples/example_utils.py
new file mode 100644
index 0000000..f4d1e4e
--- /dev/null
+++ b/python/pyarmnn/examples/example_utils.py
@@ -0,0 +1,221 @@
+# Copyright 2020 NXP
+# SPDX-License-Identifier: MIT
+
+from urllib.parse import urlparse
+import os
+from PIL import Image
+import pyarmnn as ann
+import numpy as np
+import requests
+import argparse
+import warnings
+
+
+def parse_command_line(desc: str = ""):
+    """Adds arguments to the script.
+
+    Args:
+        desc(str): Script description.
+
+    Returns:
+        Namespace: Arguments to the script command.
+    """
+    parser = argparse.ArgumentParser(description=desc)
+    parser.add_argument("-v", "--verbose", help="Increase output verbosity",
+                        action="store_true")
+    return parser.parse_args()
+
+
+def __create_network(model_file: str, backends: list, parser=None):
+    """Creates a network based on a file and parser type.
+
+    Args:
+        model_file (str): Path of the model file.
+        backends (list): List of backends to use when running inference.
+        parser_type: Parser instance. (pyarmnn.ITFliteParser/pyarmnn.IOnnxParser...)
+
+    Returns:
+        int: Network ID.
+        int: Graph ID.
+        IParser: TF Lite parser instance.
+        IRuntime: Runtime object instance.
+    """
+    args = parse_command_line()
+    options = ann.CreationOptions()
+    runtime = ann.IRuntime(options)
+
+    if parser is None:
+        # try to determine what parser to create based on model extension
+        _, ext = os.path.splitext(model_file)
+        if ext == ".onnx":
+            parser = ann.IOnnxParser()
+        elif ext == ".tflite":
+            parser = ann.ITfLiteParser()
+    assert (parser is not None)
+
+    network = parser.CreateNetworkFromBinaryFile(model_file)
+
+    preferred_backends = []
+    for b in backends:
+        preferred_backends.append(ann.BackendId(b))
+
+    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
+                                         ann.OptimizerOptions())
+    if args.verbose:
+        for m in messages:
+            warnings.warn(m)
+
+    net_id, w = runtime.LoadNetwork(opt_network)
+    if args.verbose and w:
+        warnings.warn(w)
+
+    return net_id, parser, runtime
+
+
+def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+    """Creates a network from an onnx model file.
+
+    Args:
+        model_file (str): Path of the model file.
+        backends (list): List of backends to use when running inference.
+
+    Returns:
+        int: Network ID.
+        int: Graph ID.
+        ITFliteParser: TF Lite parser instance.
+        IRuntime: Runtime object instance.
+    """
+    net_id, parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
+    graph_id = parser.GetSubgraphCount() - 1
+
+    return net_id, graph_id, parser, runtime
+
+
+def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
+    """Creates a network from a tflite model file.
+
+    Args:
+        model_file (str): Path of the model file.
+        backends (list): List of backends to use when running inference.
+
+    Returns:
+        int: Network ID.
+        IOnnxParser: ONNX parser instance.
+        IRuntime: Runtime object instance.
+    """
+    return __create_network(model_file, backends, ann.IOnnxParser())
+
+
+def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
+                       stddev: list):
+    """Default preprocessing image function.
+
+    Args:
+        img (PIL.Image): PIL.Image object instance.
+        width (int): Width to resize to.
+        height (int): Height to resize to.
+        data_type: Data Type to cast the image to.
+        scale (float): Scaling value.
+        mean (list): RGB mean offset.
+        stddev (list): RGB standard deviation.
+
+    Returns:
+        np.array: Resized and preprocessed image.
+    """
+    img = img.resize((width, height), Image.BILINEAR)
+    img = img.convert('RGB')
+    img = np.array(img)
+    img = np.reshape(img, (-1, 3))  # reshape to [RGB][RGB]...
+    img = ((img / scale) - mean) / stddev
+    img = img.flatten().astype(data_type)
+    return img
+
+
+def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
+                scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
+                preprocess_fn=preprocess_default):
+    """Loads images, resizes and performs any additional preprocessing to run inference.
+
+    Args:
+        img (list): List of PIL.Image object instances.
+        input_width (int): Width to resize to.
+        input_height (int): Height to resize to.
+        data_type: Data Type to cast the image to.
+        scale (float): Scaling value.
+        mean (list): RGB mean offset.
+        stddev (list): RGB standard deviation.
+        preprocess_fn: Preprocessing function.
+
+    Returns:
+        np.array: Resized and preprocessed images.
+    """
+    images = []
+    for i in image_files:
+        img = Image.open(i)
+        img = preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev)
+        images.append(img)
+    return images
+
+
+def load_labels(label_file: str):
+    """Loads a labels file containing a label per line.
+
+    Args:
+        label_file (str): Labels file path.
+
+    Returns:
+        list: List of labels read from a file.
+    """
+    with open(label_file, 'r') as f:
+        labels = [l.rstrip() for l in f]
+        return labels
+    return None
+
+
+def print_top_n(N: int, results: list, labels: list, prob: list):
+    """Prints TOP-N results
+
+    Args:
+        N (int): Result count to print.
+        results (list): Top prediction indices.
+        labels (list): A list of labels for every class.
+        prob (list): A list of probabilities for every class.
+
+    Returns:
+        None
+    """
+    assert (len(results) >= 1 and len(results) == len(labels) == len(prob))
+    for i in range(min(len(results), N)):
+        print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))
+
+
+def download_file(url: str, force: bool = False, filename: str = None, dest: str = "tmp"):
+    """Downloads a file.
+
+    Args:
+        url (str): File url.
+        force (bool): Forces to download the file even if it exists.
+        filename (str): Renames the file when set.
+
+    Returns:
+        str: Path to the downloaded file.
+    """
+    if filename is None:  # extract filename from url when None
+        filename = urlparse(url)
+        filename = os.path.basename(filename.path)
+
+    if str is not None:
+        if not os.path.exists(dest):
+            os.makedirs(dest)
+        filename = os.path.join(dest, filename)
+
+    print("Downloading '{0}' from '{1}' ...".format(filename, url))
+    if not os.path.exists(filename) or force is True:
+        r = requests.get(url)
+        with open(filename, 'wb') as f:
+            f.write(r.content)
+        print("Finished.")
+    else:
+        print("File already exists.")
+
+    return filename
diff --git a/python/pyarmnn/examples/onnx_mobilenetv2.py b/python/pyarmnn/examples/onnx_mobilenetv2.py
new file mode 100644
index 0000000..b6d5d8c
--- /dev/null
+++ b/python/pyarmnn/examples/onnx_mobilenetv2.py
@@ -0,0 +1,86 @@
+# Copyright 2020 NXP
+# SPDX-License-Identifier: MIT
+
+import pyarmnn as ann
+import numpy as np
+from PIL import Image
+import example_utils as eu
+
+
+def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float, mean: list,
+                    stddev: list):
+    """Preprocessing function for ONNX imagenet models based on:
+    https://github.com/onnx/models/blob/master/vision/classification/imagenet_inference.ipynb
+
+    Args:
+        img (PIL.Image): Loaded PIL.Image
+        width (int): Target image width
+        height (int): Target image height
+        data_type: Image datatype (np.uint8 or np.float32)
+        scale (float): Scaling factor
+        mean: RGB mean values
+        stddev: RGB standard deviation
+
+    Returns:
+        np.array: Preprocess image as Numpy array
+    """
+    img = img.resize((256, 256), Image.BILINEAR)
+    # first rescale to 256,256 and then center crop
+    left = (256 - width) / 2
+    top = (256 - height) / 2
+    right = (256 + width) / 2
+    bottom = (256 + height) / 2
+    img = img.crop((left, top, right, bottom))
+    img = img.convert('RGB')
+    img = np.array(img)
+    img = np.reshape(img, (-1, 3))  # reshape to [RGB][RGB]...
+    img = ((img / scale) - mean) / stddev
+    # NHWC to NCHW conversion, by default NHWC is expected
+    # image is loaded as [RGB][RGB][RGB]... transposing it makes it [RRR...][GGG...][BBB...]
+    img = np.transpose(img)
+    img = img.flatten().astype(data_type)  # flatten into a 1D tensor and convert to float32
+    return img
+
+
+if __name__ == "__main__":
+    # Download resources
+    kitten_filename = eu.download_file('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
+    labels_filename = eu.download_file('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
+    model_filename = eu.download_file(
+        'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx')
+
+    # Create a network from a model file
+    net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+    # Load input information from the model and create input tensors
+    input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+    # Load output information from the model and create output tensors
+    output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+    output_tensors = ann.make_output_tensors([output_binding_info])
+
+    # Load labels
+    labels = eu.load_labels(labels_filename)
+
+    # Load images and resize to expected size
+    image_names = [kitten_filename]
+    images = eu.load_images(image_names,
+                            224, 224,
+                            np.float32,
+                            255.0,
+                            [0.485, 0.456, 0.406],
+                            [0.229, 0.224, 0.225],
+                            preprocess_onnx)
+
+    for idx, im in enumerate(images):
+        # Create input tensors
+        input_tensors = ann.make_input_tensors([input_binding_info], [im])
+
+        # Run inference
+        print("Running inference on '{0}' ...".format(image_names[idx]))
+        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+
+        # Process output
+        out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
+        results = np.argsort(out_tensor)[::-1]
+        eu.print_top_n(5, results, labels, out_tensor)
diff --git a/python/pyarmnn/examples/requirements.txt b/python/pyarmnn/examples/requirements.txt
new file mode 100644
index 0000000..9af2b27
--- /dev/null
+++ b/python/pyarmnn/examples/requirements.txt
@@ -0,0 +1,5 @@
+requests>=2.23.0
+urllib3>=1.25.8
+Pillow>=6.1.0
+numpy>=1.18.1
+pyarmnn>=19.8.0
diff --git a/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py
new file mode 100644
index 0000000..8cc5295
--- /dev/null
+++ b/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py
@@ -0,0 +1,71 @@
+# Copyright 2020 NXP
+# SPDX-License-Identifier: MIT
+
+from zipfile import ZipFile
+import numpy as np
+import pyarmnn as ann
+import example_utils as eu
+import os
+
+
+def unzip_file(filename):
+    """Unzips a file to its current location.
+
+    Args:
+        filename (str): Name of the archive.
+
+    Returns:
+        str: Directory path of the extracted files.
+    """
+    with ZipFile(filename, 'r') as zip_obj:
+        zip_obj.extractall(os.path.dirname(filename))
+    return os.path.dirname(filename)
+
+
+if __name__ == "__main__":
+    # Download resources
+    archive_filename = eu.download_file(
+        'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip')
+    dir_path = unzip_file(archive_filename)
+    # names of the files in the archive
+    labels_filename = os.path.join(dir_path, 'labels_mobilenet_quant_v1_224.txt')
+    model_filename = os.path.join(dir_path, 'mobilenet_v1_1.0_224_quant.tflite')
+    kitten_filename = eu.download_file('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
+
+    # Create a network from the model file
+    net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
+
+    # Load input information from the model
+    # tflite has all the need information in the model unlike other formats
+    input_names = parser.GetSubgraphInputTensorNames(graph_id)
+    assert len(input_names) == 1  # there should be 1 input tensor in mobilenet
+
+    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+    input_width = input_binding_info[1].GetShape()[1]
+    input_height = input_binding_info[1].GetShape()[2]
+
+    # Load output information from the model and create output tensors
+    output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+    assert len(output_names) == 1  # and only one output tensor
+    output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
+    output_tensors = ann.make_output_tensors([output_binding_info])
+
+    # Load labels file
+    labels = eu.load_labels(labels_filename)
+
+    # Load images and resize to expected size
+    image_names = [kitten_filename]
+    images = eu.load_images(image_names, input_width, input_height)
+
+    for idx, im in enumerate(images):
+        # Create input tensors
+        input_tensors = ann.make_input_tensors([input_binding_info], [im])
+
+        # Run inference
+        print("Running inference on '{0}' ...".format(image_names[idx]))
+        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+
+        # Process output
+        out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
+        results = np.argsort(out_tensor)[::-1]
+        eu.print_top_n(5, results, labels, out_tensor)