PyArmNN Updates

* Updated setup.py to raise an error on mandatory extensions
* Updated examples section of main readme
* Added readme to the image classification example
* Moved the image classification example to a new subdirectory

Change-Id: Iea5f6d87c97e571b8ca5636268231506538840c7
Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Signed-off-by: Jakub Sujak <jakub.sujak@arm.com>
diff --git a/python/pyarmnn/examples/image_classification/example_utils.py b/python/pyarmnn/examples/image_classification/example_utils.py
new file mode 100644
index 0000000..090ce2f
--- /dev/null
+++ b/python/pyarmnn/examples/image_classification/example_utils.py
@@ -0,0 +1,358 @@
+# Copyright © 2020 NXP and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from urllib.parse import urlparse
+from PIL import Image
+from zipfile import ZipFile
+import os
+import pyarmnn as ann
+import numpy as np
+import requests
+import argparse
+import warnings
+
+DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
+
+
def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
    """Performs inference over every image and reports the top 5 classes.

    Args:
        runtime: Arm NN runtime
        net_id: Network ID
        images: Loaded images to run inference on
        labels: Loaded labels per class
        input_binding_info: Network input information
        output_binding_info: Network output information

    Returns:
        None
    """
    out_tensors = ann.make_output_tensors([output_binding_info])
    for index, image in enumerate(images):
        # Wrap the preprocessed image in an input tensor for this run
        in_tensors = ann.make_input_tensors([input_binding_info], [image])

        print("Running inference({0}) ...".format(index))
        runtime.EnqueueWorkload(net_id, in_tensors, out_tensors)

        # Pull the scores back out and rank classes from best to worst
        scores = ann.workload_tensors_to_ndarray(out_tensors)[0]
        ranked = np.argsort(scores)[::-1]
        print_top_n(5, ranked, labels, scores)
+
+
def unzip_file(filename: str):
    """Extracts the full contents of a zip archive into the current directory.

    Args:
        filename(str): Name of the file

    Returns:
        None
    """
    archive = ZipFile(filename, 'r')
    try:
        archive.extractall()
    finally:
        # Mirror the context-manager behaviour: always release the handle.
        archive.close()
+
+
def parse_command_line(desc: str = ""):
    """Builds the script's argument parser and parses the command line.

    Args:
        desc (str): Script description

    Returns:
        Namespace: Arguments to the script command
    """
    arg_parser = argparse.ArgumentParser(description=desc)
    arg_parser.add_argument("-v", "--verbose", action="store_true",
                            help="Increase output verbosity")
    arg_parser.add_argument("-d", "--data-dir", action="store", default="",
                            help="Data directory which contains all the images.")
    arg_parser.add_argument("-m", "--model-dir", action="store", default="",
                            help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).")
    return arg_parser.parse_args()
+
+
def __create_network(model_file: str, backends: list, parser=None):
    """Creates a network based on a file and parser type.

    Args:
        model_file (str): Path of the model file
        backends (list): List of backends to use when running inference.
        parser: Parser instance (pyarmnn.ITfLiteParser/pyarmnn.IOnnxParser...).
            When None, the parser is chosen from the model file extension.

    Raises:
        ValueError: If no parser was given and none could be determined
            from the model file extension.

    Returns:
        int: Network ID
        IParser: Parser instance used to load the model
        IRuntime: Runtime object instance
    """
    # NOTE(review): this re-parses sys.argv as a side effect; only the
    # --verbose flag is consumed here.
    args = parse_command_line()
    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    if parser is None:
        # Try to determine what parser to create based on the model extension.
        _, ext = os.path.splitext(model_file)
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
    if parser is None:
        # Raise instead of assert: asserts are stripped under `python -O`.
        raise ValueError("Unable to determine a parser for model file: {}".format(model_file))

    network = parser.CreateNetworkFromBinaryFile(model_file)

    # Preserve the caller's backend order: earlier entries are preferred.
    preferred_backends = [ann.BackendId(b) for b in backends]

    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    if args.verbose:
        for m in messages:
            warnings.warn(m)

    net_id, w = runtime.LoadNetwork(opt_network)
    if args.verbose and w:
        warnings.warn(w)

    return net_id, parser, runtime
+
+
def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
    """Creates a network from a tflite model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.

    Returns:
        int: Network ID.
        int: Graph ID.
        ITFliteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    net_id, tflite_parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
    # The last subgraph is the one the model executes.
    last_graph_id = tflite_parser.GetSubgraphCount() - 1
    return net_id, last_graph_id, tflite_parser, runtime
+
+
def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
    """Creates a network from an onnx model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    onnx_parser = ann.IOnnxParser()
    return __create_network(model_file, backends, onnx_parser)
+
+
def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
                       stddev: list):
    """Default preprocessing image function.

    Args:
        img (PIL.Image): PIL.Image object instance.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.

    Returns:
        np.array: Resized and preprocessed image.
    """
    resized = img.resize((width, height), Image.BILINEAR).convert('RGB')
    # Flatten the HxWx3 image into rows of [R, G, B] triples.
    pixels = np.reshape(np.array(resized), (-1, 3))
    # Normalise per channel, then cast to the network's input data type.
    normalized = ((pixels / scale) - mean) / stddev
    return normalized.flatten().astype(data_type)
+
+
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): Paths of the image files to load.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.
        preprocess_fn: Preprocessing function.

    Returns:
        list: Resized and preprocessed images, one entry per input file.
    """
    # NOTE: the mutable list defaults are kept for interface compatibility;
    # they are never mutated here.
    return [preprocess_fn(Image.open(path), input_width, input_height, data_type,
                          scale, mean, stddev)
            for path in image_files]
+
+
def load_labels(label_file: str):
    """Loads a labels file containing a label per line.

    Args:
        label_file (str): Labels file path.

    Returns:
        list: List of labels read from a file.
    """
    # A previous `return None` after the `with` block was unreachable:
    # `open` raises on a missing file and the `with` body always returns.
    with open(label_file, 'r') as f:
        return [line.rstrip() for line in f]
+
+
def print_top_n(N: int, results: list, labels: list, prob: list):
    """Prints TOP-N results

    Args:
        N (int): Result count to print.
        results (list): Top prediction indices.
        labels (list): A list of labels for every class.
        prob (list): A list of probabilities for every class.

    Raises:
        ValueError: If the inputs are empty or their lengths differ.

    Returns:
        None
    """
    # Validate with an explicit exception instead of `assert`, which is
    # stripped when Python runs with optimizations enabled (-O).
    if not (len(results) >= 1 and len(results) == len(labels) == len(prob)):
        raise ValueError("results, labels and prob must be non-empty and of equal length")
    for i in range(min(len(results), N)):
        print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))
+
+
def download_file(url: str, force: bool = False, filename: str = None):
    """Downloads a file.

    Args:
        url (str): File url.
        force (bool): Forces to download the file even if it exists.
        filename (str): Renames the file when set.

    Raises:
        RuntimeError: If for some reason download fails.

    Returns:
        str: Path to the downloaded file.
    """
    try:
        if filename is None:  # extract filename from url when None
            filename = os.path.basename(urlparse(url).path)

        if not os.path.exists(filename) or force is True:
            # Only announce a download when one actually happens.
            print("Downloading '{0}' from '{1}' ...".format(filename, url))
            r = requests.get(url)
            # Fail loudly on HTTP errors instead of writing an error page to disk.
            r.raise_for_status()
            with open(filename, 'wb') as f:
                f.write(r.content)
            print("Finished.")
        else:
            print("File already exists.")
    except Exception as e:
        # Catch Exception (not a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit) and chain the original cause.
        raise RuntimeError("Unable to download file.") from e

    return filename
+
+
def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
    """Gets model and labels.

    Args:
        model_dir(str): Folder in which model and label files can be found
        model (str): Name of the model file
        labels (str): Name of the labels file
        archive (str): Name of the archive file (optional - need to provide only labels and model)
        download_url(str or list): Archive url or urls if multiple files (optional - provide only to download it)

    Raises:
        RuntimeError: If the model and labels files cannot be provided.

    Returns:
        tuple (str, str): Output model and label filenames
    """
    labels = os.path.join(model_dir, labels)
    model = os.path.join(model_dir, model)

    if os.path.exists(labels) and os.path.exists(model):
        print("Found model ({0}) and labels ({1}).".format(model, labels))
    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
        print("Found archive ({0}). Unzipping ...".format(archive))
        # Unzip the archive at the path whose existence was just checked,
        # not the bare archive name relative to the current directory.
        unzip_file(os.path.join(model_dir, archive))
    elif download_url is not None:
        print("Model, labels or archive not found. Downloading ...")
        try:
            if isinstance(download_url, str):
                download_url = [download_url]
            for dl in download_url:
                archive = download_file(dl)
                # Unzip each downloaded archive, not just the last one:
                # this check previously sat outside the loop.
                if dl.lower().endswith(".zip"):
                    unzip_file(archive)
        except RuntimeError:
            # `dl` is the url whose download failed (was an undefined
            # `archive_url`, which raised NameError instead of printing).
            print("Unable to download file ({}).".format(dl))

    if not os.path.exists(labels) or not os.path.exists(model):
        raise RuntimeError("Unable to provide model and labels.")

    return model, labels
+
+
def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
    """Lists files of a certain format in a folder.

    Args:
        folder (str): Path to the folder to search
        formats (list): List of supported files

    Returns:
        list: A list of found files
    """
    if folder and not os.path.exists(folder):
        print("Folder '{}' does not exist.".format(folder))
        return []

    search_dir = folder if folder else os.getcwd()
    # str.endswith accepts a tuple, so all suffixes are checked in one call.
    suffixes = tuple(formats)
    found = []
    for entry in os.listdir(search_dir):
        if entry.lower().endswith(suffixes):
            found.append(os.path.join(folder, entry) if folder else entry)
    return found
+
+
def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
    """Gets local images, downloading a fallback image when none are found.

    Args:
        image_dir (str): Directory to search for images.
        image_url (str): Fallback image url to download when the directory
            holds no images. Pass None to disable the download.

    Raises:
        RuntimeError: If no images can be provided.

    Returns:
        list: Paths of the images to process.
    """
    images = list_images(image_dir)
    if not images and image_url is not None:
        print("No images found. Downloading ...")
        try:
            images = [download_file(image_url)]
        except RuntimeError:
            print("Unable to download file ({0}).".format(image_url))

    if not images:
        raise RuntimeError("Unable to provide images.")

    return images