| # Copyright 2020 NXP |
| # SPDX-License-Identifier: MIT |
| |
| from urllib.parse import urlparse |
| from PIL import Image |
| from zipfile import ZipFile |
| import os |
| import pyarmnn as ann |
| import numpy as np |
| import requests |
| import argparse |
| import warnings |
| |
# Sample image that get_images() downloads when no local images are found.
DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
| |
| |
def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
    """Runs inference on every image in turn and prints the top-5 classes for each.

    A single set of output tensors is allocated up front and reused for every image.

    Args:
        runtime: Arm NN runtime used to enqueue the workloads.
        net_id: ID of the network loaded into ``runtime``.
        images: Preprocessed images to run inference on.
        labels: Label for every output class.
        input_binding_info: Network input binding information.
        output_binding_info: Network output binding information.

    Returns:
        None
    """
    outputs = ann.make_output_tensors([output_binding_info])
    for index, image in enumerate(images):
        inputs = ann.make_input_tensors([input_binding_info], [image])

        print("Running inference({0}) ...".format(index))
        runtime.EnqueueWorkload(net_id, inputs, outputs)

        # First output tensor as a numpy array; rank class indices by descending score.
        scores = ann.workload_tensors_to_ndarray(outputs)[0]
        ranking = np.argsort(scores)[::-1]
        print_top_n(5, ranking, labels, scores)
| |
| |
def unzip_file(filename: str):
    """Extracts every member of a ZIP archive into the current working directory.

    Args:
        filename (str): Path of the ZIP archive.

    Returns:
        None
    """
    with ZipFile(filename, mode='r') as archive:
        archive.extractall()
| |
| |
def parse_command_line(desc: str = ""):
    """Builds the example scripts' argument parser and parses ``sys.argv``.

    Args:
        desc (str): Description shown in the script's help text.

    Returns:
        Namespace: Parsed arguments (``verbose``, ``data_dir``, ``model_dir``).
    """
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("-v", "--verbose",
                        action="store_true",
                        help="Increase output verbosity")
    parser.add_argument("-d", "--data-dir",
                        action="store",
                        default="",
                        help="Data directory which contains all the images.")
    parser.add_argument("-m", "--model-dir",
                        action="store",
                        default="",
                        help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).")
    parsed = parser.parse_args()
    return parsed
| |
| |
def __create_network(model_file: str, backends: list, parser=None):
    """Creates a network based on a file and parser type.

    Args:
        model_file (str): Path of the model file.
        backends (list): Names of the backends to use when running inference; they are passed
            to ``ann.Optimize`` as the preferred backends.
        parser: Parser instance (pyarmnn.ITfLiteParser / pyarmnn.IOnnxParser ...). When None, the
            parser is chosen from the model file extension (``.onnx`` or ``.tflite``,
            case-insensitive).

    Raises:
        ValueError: If ``parser`` is None and the model file extension is not recognised.

    Returns:
        int: Network ID
        IParser: Parser instance used to create the network
        IRuntime: Runtime object instance
    """
    # NOTE(review): this parses sys.argv, so a calling script that defines extra command-line
    # arguments will be rejected here by argparse. Only ``verbose`` is used below.
    args = parse_command_line()

    if parser is None:
        # Try to determine what parser to create based on the model extension.
        _, ext = os.path.splitext(model_file)
        ext = ext.lower()
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
        else:
            # Explicit error instead of `assert`, which is stripped when running with -O.
            raise ValueError("Unable to determine a parser for model file '{0}'.".format(model_file))

    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = [ann.BackendId(b) for b in backends]

    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    if args.verbose:
        for m in messages:
            warnings.warn(m)

    net_id, w = runtime.LoadNetwork(opt_network)
    if args.verbose and w:
        warnings.warn(w)

    return net_id, parser, runtime
| |
| |
def create_tflite_network(model_file: str, backends: list = None):
    """Creates a network from a tflite model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference. Defaults to
            ``['CpuAcc', 'CpuRef']`` when None (a None sentinel avoids a shared mutable default).

    Returns:
        int: Network ID.
        int: Graph ID (index of the last subgraph in the model, ``GetSubgraphCount() - 1``).
        ITfLiteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    if backends is None:
        backends = ['CpuAcc', 'CpuRef']
    net_id, parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
    graph_id = parser.GetSubgraphCount() - 1

    return net_id, graph_id, parser, runtime
| |
| |
def create_onnx_network(model_file: str, backends: list = None):
    """Creates a network from an onnx model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference. Defaults to
            ``['CpuAcc', 'CpuRef']`` when None (a None sentinel avoids a shared mutable default).

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    if backends is None:
        backends = ['CpuAcc', 'CpuRef']
    return __create_network(model_file, backends, ann.IOnnxParser())
| |
| |
def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
                       stddev: list):
    """Resizes an image, converts it to RGB and normalises it channel-wise.

    Each pixel value is transformed as ``((value / scale) - mean) / stddev``, with ``mean`` and
    ``stddev`` broadcast across the R, G and B channels. The result is then flattened and cast.

    Args:
        img (PIL.Image): PIL.Image object instance.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.

    Returns:
        np.array: 1-D array of the preprocessed pixels in [RGB][RGB]... order.
    """
    rgb = img.resize((width, height), Image.BILINEAR).convert('RGB')
    # One row per pixel, one column per colour channel.
    pixels = np.array(rgb).reshape(-1, 3)
    normalised = (pixels / scale - mean) / stddev
    return normalised.flatten().astype(data_type)
| |
| |
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = None, stddev: list = None,
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): Paths of the image files to load.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset. Defaults to ``[0., 0., 0.]`` when None.
        stddev (list): RGB standard deviation. Defaults to ``[1., 1., 1.]`` when None.
        preprocess_fn: Preprocessing function, called as
            ``preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev)``.

    Returns:
        list: The value returned by ``preprocess_fn`` for each image, in input order.
    """
    # None sentinels avoid sharing mutable default lists between calls.
    if mean is None:
        mean = [0., 0., 0.]
    if stddev is None:
        stddev = [1., 1., 1.]

    images = []
    for path in image_files:
        # Close the underlying file once preprocessing is done instead of leaking the handle.
        # NOTE(review): assumes preprocess_fn returns data independent of the opened file
        # (true for preprocess_default, which returns a numpy array).
        with Image.open(path) as img:
            images.append(preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev))
    return images
| |
| |
def load_labels(label_file: str):
    """Loads a labels file containing a label per line.

    Trailing whitespace (including the line terminator) is stripped from every line; empty
    lines are kept as empty labels so indices still match class IDs.

    Args:
        label_file (str): Labels file path.

    Raises:
        OSError: If the file cannot be opened.

    Returns:
        list: List of labels read from a file.
    """
    with open(label_file, 'r') as f:
        return [line.rstrip() for line in f]
| |
| |
def print_top_n(N: int, results: list, labels: list, prob: list):
    """Prints the label and score of the first N entries of a ranking.

    Args:
        N (int): Maximum number of results to print.
        results (list): Class indices, best prediction first.
        labels (list): A label for every class.
        prob (list): A score for every class.

    Returns:
        None
    """
    total = len(results)
    assert (total >= 1 and total == len(labels) and total == len(prob))
    shown = min(total, N)
    for rank in range(shown):
        class_idx = results[rank]
        print("class={0} ; value={1}".format(labels[class_idx], prob[class_idx]))
| |
| |
def download_file(url: str, force: bool = False, filename: str = None, timeout: float = 60.):
    """Downloads a file.

    Args:
        url (str): File url.
        force (bool): Forces to download the file even if it exists.
        filename (str): Renames the file when set; otherwise the last path component of the url
            is used.
        timeout (float): Seconds to wait for the server to connect/send data before giving up
            (passed to ``requests.get``).

    Raises:
        RuntimeError: If for some reason download fails (network error, HTTP error status,
            file write error, malformed url). The original exception is chained as the cause.

    Returns:
        str: Path to the downloaded file.
    """
    try:
        if filename is None:  # extract filename from url when None
            filename = os.path.basename(urlparse(url).path)

        print("Downloading '{0}' from '{1}' ...".format(filename, url))
        if not os.path.exists(filename) or force:
            r = requests.get(url, timeout=timeout)
            # Don't save an HTTP error page (e.g. 404) as if it were the requested file.
            r.raise_for_status()
            with open(filename, 'wb') as f:
                f.write(r.content)
            print("Finished.")
        else:
            print("File already exists.")
    except (requests.RequestException, OSError, ValueError) as err:
        raise RuntimeError("Unable to download file.") from err

    return filename
| |
| |
def _extract_archive(archive_path: str, dest_dir: str):
    """Extracts every member of a ZIP archive into dest_dir (the current directory when empty)."""
    with ZipFile(archive_path, 'r') as zip_obj:
        zip_obj.extractall(dest_dir or None)


def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
    """Gets model and labels.

    The model and labels files are looked up inside ``model_dir``. When either is missing, they
    are obtained by (first applicable option):
      1. extracting ``archive`` when it exists inside ``model_dir``, or
      2. downloading every url of ``download_url`` into ``model_dir`` and extracting those whose
         url ends in ``.zip``.
    Extraction and downloads target ``model_dir`` so the files end up where they are looked for.

    Args:
        model_dir(str): Folder in which model and label files can be found
        model (str): Name of the model file
        labels (str): Name of the labels file
        archive (str): Name of the archive file inside model_dir (optional - only needed when
            the model and labels are shipped in an archive)
        download_url(str or list): Archive url or urls if multiple files (optional - only needed
            to download them)

    Raises:
        RuntimeError: If the model or labels file still cannot be found afterwards.

    Returns:
        tuple (str, str): Model and labels file paths (joined with model_dir)
    """
    labels = os.path.join(model_dir, labels)
    model = os.path.join(model_dir, model)

    if os.path.exists(labels) and os.path.exists(model):
        print("Found model ({0}) and labels ({1}).".format(model, labels))
    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
        print("Found archive ({0}). Unzipping ...".format(archive))
        # Use the same joined path that was checked above (not the bare archive name).
        _extract_archive(os.path.join(model_dir, archive), model_dir)
    elif download_url is not None:
        print("Model, labels or archive not found. Downloading ...")
        urls = [download_url] if isinstance(download_url, str) else download_url
        for url in urls:
            try:
                target = os.path.join(model_dir, os.path.basename(urlparse(url).path))
                downloaded = download_file(url, filename=target)
                if url.lower().endswith(".zip"):
                    _extract_archive(downloaded, model_dir)
            except RuntimeError:
                print("Unable to download file ({}).".format(url))
                break  # stop at the first failed download

    if not os.path.exists(labels) or not os.path.exists(model):
        raise RuntimeError("Unable to provide model and labels.")

    return model, labels
| |
| |
def list_images(folder: str = None, formats: list = None):
    """Lists files of a certain format in a folder.

    File extensions are matched case-insensitively, and only regular files are returned
    (directories with image-like names are skipped).

    Args:
        folder (str): Path to the folder to search; the current working directory when empty
            or None.
        formats (list): List of supported file extensions. Defaults to ``['.jpg', '.jpeg']``
            when None.

    Returns:
        list: A list of found files, prefixed with ``folder`` when it is given. The order follows
            ``os.listdir`` and is therefore not guaranteed.
    """
    if formats is None:
        formats = ['.jpg', '.jpeg']

    if folder and not os.path.exists(folder):
        print("Folder '{}' does not exist.".format(folder))
        return []

    # str.endswith accepts a tuple; lower-case the formats since names are lower-cased too.
    suffixes = tuple(frmt.lower() for frmt in formats)
    files = []
    for name in os.listdir(folder if folder else os.getcwd()):
        if not name.lower().endswith(suffixes):
            continue
        path = os.path.join(folder, name) if folder else name
        if os.path.isfile(path):
            files.append(path)

    return files
| |
| |
def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
    """Gets the images to run inference on.

    Images already present in ``image_dir`` are preferred. When none are found and
    ``image_url`` is set, that single image is downloaded instead.

    Args:
        image_dir (str): Folder to search for images.
        image_url (str): Url of a fallback image; pass None to disable downloading.

    Raises:
        RuntimeError: If no image can be found or downloaded.

    Returns:
        list: Image file paths.
    """
    found = list_images(image_dir)
    if found:
        return found

    if image_url is not None:
        print("No images found. Downloading ...")
        try:
            return [download_file(image_url)]
        except RuntimeError:
            print("Unable to download file ({0}).".format(image_url))

    raise RuntimeError("Unable to provide images.")