Pavel Macenauer | d0fedae | 2020-04-15 14:52:57 +0000 | [diff] [blame] | 1 | # Copyright 2020 NXP |
| 2 | # SPDX-License-Identifier: MIT |
| 3 | |
| 4 | from urllib.parse import urlparse |
| 5 | import os |
| 6 | from PIL import Image |
| 7 | import pyarmnn as ann |
| 8 | import numpy as np |
| 9 | import requests |
| 10 | import argparse |
| 11 | import warnings |
| 12 | |
| 13 | |
def parse_command_line(desc: str = ""):
    """Builds the command line parser for a script and parses its arguments.

    The only option registered is the boolean ``-v`` / ``--verbose`` switch.

    Args:
        desc(str): Script description shown in the help output.

    Returns:
        Namespace: Parsed arguments of the script command.
    """
    arg_parser = argparse.ArgumentParser(description=desc)
    arg_parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Increase output verbosity",
    )
    parsed = arg_parser.parse_args()
    return parsed
| 27 | |
| 28 | |
def __create_network(model_file: str, backends: list, parser=None):
    """Creates a network based on a file and parser type.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.
        parser: Parser instance (pyarmnn.ITfLiteParser/pyarmnn.IOnnxParser...).
            When None, the parser is chosen from the model file extension
            (".onnx" or ".tflite").

    Returns:
        int: Network ID.
        IParser: Parser instance that was used to read the model.
        IRuntime: Runtime object instance.

    Raises:
        ValueError: If no parser is given and the model file extension is
            neither ".onnx" nor ".tflite".
    """
    # The verbosity switch is read from the command line of the running script.
    args = parse_command_line()
    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    if parser is None:
        # try to determine what parser to create based on model extension
        _, ext = os.path.splitext(model_file)
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
        else:
            # Explicit raise instead of assert: asserts are removed under -O.
            raise ValueError(
                "Cannot determine a parser for model file '{0}': "
                "unsupported extension '{1}'".format(model_file, ext))

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = [ann.BackendId(b) for b in backends]

    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    # Optimizer messages are only surfaced when verbose output was requested.
    if args.verbose:
        for m in messages:
            warnings.warn(m)

    net_id, w = runtime.LoadNetwork(opt_network)
    if args.verbose and w:
        warnings.warn(w)

    return net_id, parser, runtime
| 73 | |
| 74 | |
def create_tflite_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
    """Builds and loads a network from a tflite model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.

    Returns:
        int: Network ID.
        int: Graph ID, computed as the parser's subgraph count minus one.
        ITFliteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    tflite_parser = ann.ITfLiteParser()
    network_id, used_parser, armnn_runtime = __create_network(model_file, backends, tflite_parser)

    # The graph ID is the index of the last subgraph reported by the parser.
    subgraph_count = used_parser.GetSubgraphCount()
    last_graph_id = subgraph_count - 1

    return network_id, last_graph_id, used_parser, armnn_runtime
| 92 | |
| 93 | |
def create_onnx_network(model_file: str, backends: list = ['CpuAcc', 'CpuRef']):
    """Builds and loads a network from an onnx model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    onnx_parser = ann.IOnnxParser()
    created = __create_network(model_file, backends, onnx_parser)
    return created
| 107 | |
| 108 | |
def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
                       stddev: list):
    """Default image preprocessing: resize, convert to RGB, normalize, flatten.

    Each pixel is transformed as ``((pixel / scale) - mean) / stddev`` before
    the result is flattened and cast to ``data_type``.

    Args:
        img (PIL.Image): PIL.Image object instance.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.

    Returns:
        np.array: Resized and preprocessed image as a flat array.
    """
    target_size = (width, height)
    rgb_image = img.resize(target_size, Image.BILINEAR).convert('RGB')

    # One row per pixel: [RGB][RGB]...
    pixels = np.reshape(np.array(rgb_image), (-1, 3))

    scaled = pixels / scale
    normalized = (scaled - mean) / stddev

    return normalized.flatten().astype(data_type)
| 132 | |
| 133 | |
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = [0., 0., 0.], stddev: list = [1., 1., 1.],
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): Image files to open with PIL.Image.open.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.
        preprocess_fn: Preprocessing function. It is called as
            preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev)
            while the image file is open, so it must finish reading the pixel
            data before it returns.

    Returns:
        list: One preprocessed image per entry of image_files, in the same order.
    """
    images = []
    for image_file in image_files:
        # Close the underlying file as soon as preprocessing is done so that
        # long file lists do not leak file handles.
        with Image.open(image_file) as img:
            images.append(
                preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev))
    return images
| 158 | |
| 159 | |
def load_labels(label_file: str):
    """Loads a labels file containing a label per line.

    Trailing whitespace (including the line terminator) is stripped from every
    line; empty lines are kept as empty labels so indices stay aligned with the
    file's line numbers.

    Args:
        label_file (str): Labels file path.

    Returns:
        list: List of labels read from a file.
    """
    with open(label_file, 'r') as f:
        return [line.rstrip() for line in f]
| 173 | |
| 174 | |
def print_top_n(N: int, results: list, labels: list, prob: list):
    """Prints the TOP-N results, one "class=... ; value=..." line per result.

    At most ``len(results)`` lines are printed, even when N is larger.

    Args:
        N (int): Result count to print.
        results (list): Top prediction indices.
        labels (list): A list of labels for every class.
        prob (list): A list of probabilities for every class.

    Returns:
        None
    """
    result_count = len(results)
    assert (result_count >= 1 and result_count == len(labels) == len(prob))

    line_template = "class={0} ; value={1}"
    lines_to_print = min(result_count, N)
    for rank in range(lines_to_print):
        class_index = results[rank]
        print(line_template.format(labels[class_index], prob[class_index]))
| 190 | |
| 191 | |
def download_file(url: str, force: bool = False, filename: str = None, dest: str = "tmp"):
    """Downloads a file.

    The download is skipped when the target file already exists, unless
    ``force`` is True.

    Args:
        url (str): File url.
        force (bool): Forces to download the file even if it exists.
        filename (str): Renames the file when set. When None, the last path
            component of the url is used.
        dest (str): Directory to store the file in; created when missing.
            When None or empty, the file name is used as given.

    Returns:
        str: Path to the downloaded file.

    Raises:
        ValueError: If no filename is given and none can be taken from the url.
        requests.HTTPError: If the server answers with an error status.
    """
    if filename is None:  # extract filename from url when None
        filename = os.path.basename(urlparse(url).path)

    if not filename:
        # e.g. a url ending in "/": joining with dest would yield the directory.
        raise ValueError("Cannot determine a file name from url '{0}'".format(url))

    if dest:
        # exist_ok avoids failing when the directory appears between check and creation.
        os.makedirs(dest, exist_ok=True)
        filename = os.path.join(dest, filename)

    print("Downloading '{0}' from '{1}' ...".format(filename, url))
    if not os.path.exists(filename) or force is True:
        r = requests.get(url)
        # Do not store an HTTP error page as if it were the requested file.
        r.raise_for_status()
        with open(filename, 'wb') as f:
            f.write(r.content)
        print("Finished.")
    else:
        print("File already exists.")

    return filename