Éanna Ó Catháin | 6c3dee4 | 2020-09-10 13:02:37 +0100 | [diff] [blame] | 1 | # Copyright © 2020 NXP and Contributors. All rights reserved. |
Pavel Macenauer | d0fedae | 2020-04-15 14:52:57 +0000 | [diff] [blame] | 2 | # SPDX-License-Identifier: MIT |
| 3 | |
| 4 | from urllib.parse import urlparse |
Pavel Macenauer | d0fedae | 2020-04-15 14:52:57 +0000 | [diff] [blame] | 5 | from PIL import Image |
Pavel Macenauer | 09daef8 | 2020-06-02 11:54:59 +0000 | [diff] [blame] | 6 | from zipfile import ZipFile |
| 7 | import os |
Pavel Macenauer | d0fedae | 2020-04-15 14:52:57 +0000 | [diff] [blame] | 8 | import pyarmnn as ann |
| 9 | import numpy as np |
| 10 | import requests |
| 11 | import argparse |
| 12 | import warnings |
| 13 | |
# Fallback image that get_images() downloads when no local images are found.
DEFAULT_IMAGE_URL = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg"
| 15 | |
| 16 | |
def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
    """Run the loaded network once per image and print the top-5 classes.

    Args:
        runtime: Arm NN runtime that owns the loaded network.
        net_id: ID of the network returned by ``runtime.LoadNetwork``.
        images: Preprocessed images; each one is bound as the single network input.
        labels: Class labels, one per output score.
        input_binding_info: Binding information of the network input.
        output_binding_info: Binding information of the network output.

    Returns:
        None
    """
    # A single set of output tensors is allocated up front and reused for every run.
    outputs = ann.make_output_tensors([output_binding_info])

    for index, image in enumerate(images):
        inputs = ann.make_input_tensors([input_binding_info], [image])

        print("Running inference({0}) ...".format(index))
        runtime.EnqueueWorkload(net_id, inputs, outputs)

        # The example model's output tensor has shape (1, 1001); take the
        # first output tensor and its single batch row.
        scores = ann.workload_tensors_to_ndarray(outputs)[0][0]
        ranking = np.argsort(scores)[::-1]  # class indices, highest score first
        print_top_n(5, ranking, labels, scores)
| 45 | |
| 46 | |
def unzip_file(filename: str):
    """Extract every member of a zip archive into the current working directory.

    Args:
        filename (str): Path of the zip archive to extract.

    Returns:
        None
    """
    archive = ZipFile(filename, mode='r')
    with archive:
        # No destination given, so extractall() writes into the current working directory.
        archive.extractall()
| 58 | |
Pavel Macenauer | d0fedae | 2020-04-15 14:52:57 +0000 | [diff] [blame] | 59 | |
def parse_command_line(desc: str = ""):
    """Build the example's argument parser and parse ``sys.argv``.

    Recognised options: ``-v/--verbose`` (flag), ``-d/--data-dir`` and
    ``-m/--model-dir`` (both default to an empty string).

    Args:
        desc (str): Description shown in the ``--help`` output.

    Returns:
        Namespace: Parsed arguments (``verbose``, ``data_dir``, ``model_dir``).
    """
    arg_parser = argparse.ArgumentParser(description=desc)

    # (flags, keyword arguments) for each option, registered in this order.
    option_specs = (
        (("-v", "--verbose"),
         {"help": "Increase output verbosity", "action": "store_true"}),
        (("-d", "--data-dir"),
         {"help": "Data directory which contains all the images.",
          "action": "store", "default": ""}),
        (("-m", "--model-dir"),
         {"help": "Model directory which contains the model file (TFLite, ONNX).",
          "action": "store", "default": ""}),
    )
    for flags, settings in option_specs:
        arg_parser.add_argument(*flags, **settings)

    return arg_parser.parse_args()
| 78 | |
| 79 | |
def __create_network(model_file: str, backends: list, parser=None):
    """Creates a network based on a file and parser type.

    Args:
        model_file (str): Path of the model file.
        backends (list): Backend names (e.g. 'CpuAcc', 'CpuRef'), in order of
            preference, used when optimizing the network.
        parser: Parser instance (pyarmnn.ITfLiteParser / pyarmnn.IOnnxParser ...).
            When None, the parser is chosen from the model file extension
            ('.onnx' or '.tflite', case-insensitive).

    Returns:
        int: Network ID
        IParser: Parser instance used to create the network
        IRuntime: Runtime object instance

    Raises:
        ValueError: If ``parser`` is None and the model file extension is not
            '.onnx' or '.tflite'.
    """
    if parser is None:
        # Try to determine which parser to create from the model extension.
        ext = os.path.splitext(model_file)[1].lower()
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
        else:
            # An explicit exception (not `assert`) so the check survives `python -O`.
            raise ValueError("Unable to determine a parser for model file '{0}'; "
                             "expected a '.onnx' or '.tflite' extension.".format(model_file))

    # NOTE(review): parse_command_line() parses sys.argv only to read --verbose;
    # argparse exits if the calling script was given options it does not know.
    args = parse_command_line()
    runtime = ann.IRuntime(ann.CreationOptions())

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = [ann.BackendId(b) for b in backends]
    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    if args.verbose:
        for m in messages:
            warnings.warn(m)

    net_id, w = runtime.LoadNetwork(opt_network)
    if args.verbose and w:
        warnings.warn(w)

    return net_id, parser, runtime
| 123 | |
| 124 | |
def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
    """Build and load an Arm NN network from a TensorFlow Lite model file.

    Args:
        model_file (str): Path of the .tflite model file.
        backends (list): Backend names to use for inference, in order of preference.

    Returns:
        int: Network ID.
        int: Graph ID (index of the parser's last subgraph).
        ITFliteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    tflite_parser = ann.ITfLiteParser()
    net_id, parser, runtime = __create_network(model_file, backends, tflite_parser)

    last_subgraph = parser.GetSubgraphCount() - 1
    return net_id, last_subgraph, parser, runtime
| 142 | |
| 143 | |
def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
    """Build and load an Arm NN network from an ONNX model file.

    Args:
        model_file (str): Path of the .onnx model file.
        backends (list): Backend names to use for inference, in order of preference.

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    onnx_parser = ann.IOnnxParser()
    net_id, parser, runtime = __create_network(model_file, backends, onnx_parser)
    return net_id, parser, runtime
| 157 | |
| 158 | |
def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
                       stddev: list):
    """Default preprocessing image function.

    Resizes the image (bilinear), converts it to RGB and normalizes every
    channel as ``((pixel / scale) - mean) / stddev`` before casting.

    Args:
        img (PIL.Image): PIL.Image object instance.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset (one value per channel).
        stddev (list): RGB standard deviation (one value per channel).

    Returns:
        np.array: Flattened, channel-interleaved [R, G, B, R, G, B, ...] array
        of length ``width * height * 3``.
    """
    # Pillow ignores the requested filter for bilevel ("1") and palette ("P")
    # images and always resizes them with NEAREST. Converting those to RGB first
    # lets the BILINEAR resize actually apply; other modes keep the resize-first order.
    if img.mode in ("1", "P"):
        img = img.convert('RGB')
    img = img.resize((width, height), Image.BILINEAR)
    img = img.convert('RGB')

    pixels = np.reshape(np.array(img), (-1, 3))  # one row per pixel: [RGB][RGB]...
    pixels = ((pixels / scale) - mean) / stddev   # mean/stddev broadcast per channel
    return pixels.flatten().astype(data_type)
| 182 | |
| 183 | |
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.),
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): Paths of the image files to load.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.
        preprocess_fn: Callable ``(img, width, height, data_type, scale, mean, stddev)``
            returning the preprocessed image. It is called while the image file
            is open and its result must not depend on the file staying open
            (the default returns a NumPy array).

    Returns:
        list: One preprocessed image per entry of ``image_files``, in the same order.
    """
    images = []
    for path in image_files:
        # Image.open keeps the underlying file open until the image is closed;
        # the context manager releases the handle instead of leaking one per image.
        with Image.open(path) as img:
            images.append(preprocess_fn(img, input_width, input_height, data_type,
                                        scale, mean, stddev))
    return images
| 208 | |
| 209 | |
def load_labels(label_file: str):
    """Loads a labels file containing a label per line.

    The file is decoded as UTF-8 independently of the platform locale; a
    leading UTF-8 byte-order mark, if present, is dropped rather than being
    glued onto the first label.

    Args:
        label_file (str): Labels file path.

    Returns:
        list: Labels in file order, with trailing whitespace stripped.
    """
    # 'utf-8-sig' reads plain UTF-8 as well and also strips an optional BOM.
    with open(label_file, 'r', encoding='utf-8-sig') as f:
        return [line.rstrip() for line in f]
Pavel Macenauer | d0fedae | 2020-04-15 14:52:57 +0000 | [diff] [blame] | 222 | |
| 223 | |
def print_top_n(N: int, results: list, labels: list, prob: list):
    """Print the label and score of the first N ranked classes.

    Args:
        N (int): Maximum number of classes to print.
        results (list): Class indices, already ordered best-first.
        labels (list): Label for every class.
        prob (list): Score for every class.

    Returns:
        None
    """
    size = len(results)
    assert size >= 1 and len(labels) == size and len(prob) == size

    shown = min(size, N)
    for rank in range(shown):
        cls = results[rank]
        print("class={0} ; value={1}".format(labels[cls], prob[cls]))
| 239 | |
| 240 | |
def download_file(url: str, force: bool = False, filename: str = None):
    """Downloads a file.

    Args:
        url (str): File url.
        force (bool): Forces to download the file even if it exists.
        filename (str): Renames the file when set; otherwise the last path
            component of the url is used (query string excluded).

    Raises:
        RuntimeError: If for some reason download fails, including an HTTP
            error status from the server or a failure to write the file.
            The original exception is chained as the cause.

    Returns:
        str: Path to the downloaded file.
    """
    try:
        if filename is None:  # extract filename from url when None
            filename = os.path.basename(urlparse(url).path)

        print("Downloading '{0}' from '{1}' ...".format(filename, url))
        if os.path.exists(filename) and not force:
            print("File already exists.")
        else:
            response = requests.get(url)
            # Without this an HTTP error page (404, 500, ...) would be saved
            # to disk as if it were the requested file.
            response.raise_for_status()
            with open(filename, 'wb') as f:
                f.write(response.content)
            print("Finished.")
    except Exception as err:
        # Not a bare `except:` -> KeyboardInterrupt/SystemExit still propagate.
        raise RuntimeError("Unable to download file.") from err

    return filename
Pavel Macenauer | 09daef8 | 2020-06-02 11:54:59 +0000 | [diff] [blame] | 272 | |
| 273 | |
def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
    """Gets model and labels.

    Looks for ``model`` and ``labels`` inside ``model_dir``. If either is
    missing, extracts ``archive`` (when it exists inside ``model_dir``) into
    ``model_dir``, or otherwise downloads ``download_url`` into ``model_dir``
    and extracts any downloaded '.zip' there. An empty ``model_dir`` means the
    current working directory.

    Args:
        model_dir(str): Folder in which model and label files can be found
        model (str): Name of the model file
        labels (str): Name of the labels file
        archive (str): Name of the archive file inside model_dir (optional - need to provide only labels and model)
        download_url(str or list): Archive url or urls if multiple files (optional - need to provide only to download it)

    Returns:
        tuple (str, str): Model and labels file paths, in that order.

    Raises:
        RuntimeError: If the model or labels file is still missing afterwards.
    """
    labels = os.path.join(model_dir, labels)
    model = os.path.join(model_dir, model)
    # None makes extractall() use the current working directory, which is
    # where os.path.join("", name) points when model_dir is empty.
    extract_dir = model_dir or None

    if os.path.exists(labels) and os.path.exists(model):
        print("Found model ({0}) and labels ({1}).".format(model, labels))
    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
        print("Found archive ({0}). Unzipping ...".format(archive))
        # Open the archive at the same location whose existence was just checked.
        with ZipFile(os.path.join(model_dir, archive), 'r') as zip_obj:
            zip_obj.extractall(extract_dir)
    elif download_url is not None:
        print("Model, labels or archive not found. Downloading ...")
        urls = [download_url] if isinstance(download_url, str) else download_url
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        try:
            for dl in urls:
                # Save into model_dir so the existence check below can find the files.
                target = os.path.join(model_dir, os.path.basename(urlparse(dl).path))
                downloaded = download_file(dl, filename=target)
                if downloaded.lower().endswith(".zip"):
                    with ZipFile(downloaded, 'r') as zip_obj:
                        zip_obj.extractall(extract_dir)
        except RuntimeError:
            print("Unable to download file ({}).".format(download_url))

    if not os.path.exists(labels) or not os.path.exists(model):
        raise RuntimeError("Unable to provide model and labels.")

    return model, labels
| 311 | |
| 312 | |
def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')):
    """Return the files in a folder whose names end with one of the given suffixes.

    Matching is case-insensitive on the file name. Results follow
    ``os.listdir`` order.

    Args:
        folder (str): Folder to search; a falsy value means the current working directory.
        formats (list): Accepted (lower-case) file suffixes.

    Returns:
        list: Matching paths, joined with ``folder`` when one was given, or bare
        file names otherwise. Empty if ``folder`` does not exist.
    """
    if folder and not os.path.exists(folder):
        print("Folder '{}' does not exist.".format(folder))
        return []

    suffixes = tuple(formats)  # str.endswith accepts a tuple of alternatives
    search_dir = folder if folder else os.getcwd()
    return [os.path.join(folder, name) if folder else name
            for name in os.listdir(search_dir)
            if name.lower().endswith(suffixes)]
| 335 | |
| 336 | |
def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
    """Collect the images to run inference on, downloading one if none exist.

    Args:
        image_dir (str): Folder searched for images by ``list_images`` (a falsy
            value means the current working directory).
        image_url (str): Image downloaded when no local image is found; pass
            None to disable the download.

    Returns:
        list: Image file paths.

    Raises:
        RuntimeError: If no image could be found or downloaded.
    """
    found = list_images(image_dir)
    if found:
        return found

    if image_url is not None:
        print("No images found. Downloading ...")
        try:
            return [download_file(image_url)]
        except RuntimeError:
            print("Unable to download file ({0}).".format(image_url))

    raise RuntimeError("Unable to provide images.")