blob: 657f9d35594514ee7b26c8d561430b6b00bef1a0 [file] [log] [blame]
Éanna Ó Catháin6c3dee42020-09-10 13:02:37 +01001# Copyright © 2020 NXP and Contributors. All rights reserved.
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00002# SPDX-License-Identifier: MIT
3
4from urllib.parse import urlparse
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00005from PIL import Image
Pavel Macenauer09daef82020-06-02 11:54:59 +00006from zipfile import ZipFile
7import os
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00008import pyarmnn as ann
9import numpy as np
10import requests
11import argparse
12import warnings
13
# Fallback image fetched by get_images() when the image directory holds no images.
DEFAULT_IMAGE_URL = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg"
15
16
def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
    """Runs the loaded network on each image and prints its five best-scoring classes.

    Args:
        runtime: Arm NN runtime the network is loaded into.
        net_id: Identifier of the loaded network.
        images: Preprocessed images, each fed as the single network input.
        labels: Class labels, indexed by class id.
        input_binding_info: Binding information of the network input.
        output_binding_info: Binding information of the network output.

    Returns:
        None
    """
    # One set of output tensors is allocated up front and reused by every workload.
    output_tensors = ann.make_output_tensors([output_binding_info])

    for image_index, image in enumerate(images):
        input_tensors = ann.make_input_tensors([input_binding_info], [image])

        print("Running inference({0}) ...".format(image_index))
        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)

        # First output tensor, first batch entry: one score per class
        # (the model output is expected to have shape (1, 1001)).
        scores = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
        ranked_classes = np.argsort(scores)[::-1]  # class ids, highest score first
        print_top_n(5, ranked_classes, labels, scores)
45
46
def unzip_file(filename: str):
    """Extracts every member of a zip archive into the current working directory.

    Args:
        filename (str): Path of the zip archive to extract.

    Returns:
        None
    """
    with ZipFile(filename, mode='r') as archive:
        archive.extractall()
58
Pavel Macenauerd0fedae2020-04-15 14:52:57 +000059
def parse_command_line(desc: str = ""):
    """Builds the example scripts' command-line interface and parses ``sys.argv``.

    Args:
        desc (str): Program description shown by ``--help``.

    Returns:
        Namespace: Parsed arguments with attributes ``verbose`` (bool),
        ``data_dir`` (str, default "") and ``model_dir`` (str, default "").
    """
    arg_parser = argparse.ArgumentParser(description=desc)
    arg_parser.add_argument("-v", "--verbose", action="store_true",
                            help="Increase output verbosity")

    # Both directory options share the same action and empty default.
    directory_options = (
        ("-d", "--data-dir", "Data directory which contains all the images."),
        ("-m", "--model-dir", "Model directory which contains the model file (TFLite, ONNX)."),
    )
    for short_flag, long_flag, help_text in directory_options:
        arg_parser.add_argument(short_flag, long_flag, help=help_text,
                                action="store", default="")

    return arg_parser.parse_args()
78
79
def __create_network(model_file: str, backends: list, parser=None, verbose: bool = None):
    """Creates, optimizes and loads a network based on a file and parser type.

    Args:
        model_file (str): Path of the model file.
        backends (list): Backend names to use when running inference, in order of preference.
        parser: Parser instance (pyarmnn.ITfLiteParser / pyarmnn.IOnnxParser ...).
            When None, one is chosen from the model file extension
            ('.onnx' or '.tflite', case-insensitive).
        verbose (bool): Whether to emit optimizer and loader messages as warnings.
            When None (default), the ``--verbose`` flag is read from the command
            line via ``parse_command_line()``.

    Returns:
        int: Network ID.
        IParser: Parser instance that created the network.
        IRuntime: Runtime object instance.

    Raises:
        ValueError: If ``parser`` is None and the model extension is not recognised.
    """
    if verbose is None:
        # Backwards-compatible default. This parses sys.argv, and argparse exits on
        # arguments parse_command_line() does not define; pass verbose explicitly
        # to avoid that.
        verbose = parse_command_line().verbose

    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    if parser is None:
        # Try to determine what parser to create based on model extension.
        _, ext = os.path.splitext(model_file)
        ext = ext.lower()
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
        else:
            # An explicit error, unlike an assert, is not stripped under `python -O`.
            raise ValueError("Unable to determine a parser for model file '{0}'.".format(model_file))

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = [ann.BackendId(b) for b in backends]

    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    if verbose:
        for m in messages:
            warnings.warn(m)

    net_id, w = runtime.LoadNetwork(opt_network)
    if verbose and w:
        warnings.warn(w)

    return net_id, parser, runtime
123
124
def create_tflite_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
    """Builds, optimizes and loads a network from a TensorFlow Lite model file.

    Args:
        model_file (str): Path of the ``.tflite`` model file.
        backends (list): Backend names to use when running inference, in order of preference.

    Returns:
        int: Network ID.
        int: Graph ID (index of the parser's last subgraph).
        ITfLiteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    net_id, tflite_parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
    return net_id, tflite_parser.GetSubgraphCount() - 1, tflite_parser, runtime
142
143
def create_onnx_network(model_file: str, backends: list = ('CpuAcc', 'CpuRef')):
    """Builds, optimizes and loads a network from an ONNX model file.

    Args:
        model_file (str): Path of the ``.onnx`` model file.
        backends (list): Backend names to use when running inference, in order of preference.

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    onnx_parser = ann.IOnnxParser()
    net_id, onnx_parser, runtime = __create_network(model_file, backends, onnx_parser)
    return net_id, onnx_parser, runtime
157
158
def preprocess_default(img: Image, width: int, height: int, data_type, scale: float, mean: list,
                       stddev: list):
    """Default preprocessing image function.

    Converts the image to RGB, resizes it with bilinear filtering and normalizes
    every channel value as ``((value / scale) - mean) / stddev``, using the
    per-channel ``mean`` and ``stddev``.

    Args:
        img (PIL.Image.Image): Image to preprocess.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: NumPy data type to cast the result to.
        scale (float): Value every channel value is divided by first.
        mean (list): RGB mean offset, subtracted after scaling.
        stddev (list): RGB standard deviation, divided by last.

    Returns:
        np.array: Flattened preprocessed image in [RGB][RGB]... order.
    """
    # Convert before resizing: Pillow always uses nearest-neighbour resampling for
    # "1" and "P" (palette) mode images, which would ignore the bilinear filter.
    img = img.convert('RGB')
    img = img.resize((width, height), Image.BILINEAR)
    pixels = np.array(img).reshape(-1, 3)  # one row per pixel: [R, G, B]
    pixels = ((pixels / scale) - mean) / stddev
    return pixels.flatten().astype(data_type)
182
183
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = (0., 0., 0.), stddev: list = (1., 1., 1.),
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): Paths of the image files to load.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.
        preprocess_fn: Preprocessing function, called as
            ``preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev)``.
            It must read the pixel data (e.g. resize/convert or ``np.array``)
            before returning, because the image file is closed afterwards.

    Returns:
        list: The value returned by ``preprocess_fn`` for each file, in input order.
    """
    images = []
    for image_file in image_files:
        # The context manager closes the file Pillow opened. Without it,
        # multi-frame formats keep the handle open after loading.
        with Image.open(image_file) as img:
            images.append(preprocess_fn(img, input_width, input_height, data_type,
                                        scale, mean, stddev))
    return images
208
209
def load_labels(label_file: str):
    """Reads class labels from a text file holding one label per line.

    Args:
        label_file (str): Path of the labels file.

    Returns:
        list: The file's lines, with trailing whitespace (including the newline) removed.
    """
    with open(label_file, 'r') as label_stream:
        return list(map(str.rstrip, label_stream))
Pavel Macenauerd0fedae2020-04-15 14:52:57 +0000222
223
def print_top_n(N: int, results: list, labels: list, prob: list):
    """Prints the label and score of the first ``N`` entries of ``results``.

    Args:
        N (int): Maximum number of results to print.
        results (list): Class indices, best prediction first.
        labels (list): Label of every class.
        prob (list): Score of every class.

    Returns:
        None
    """
    assert (len(results) >= 1 and len(labels) == len(results) == len(prob))
    # zip with range(N) stops after N entries or at the end of results, whichever comes first.
    for _, class_idx in zip(range(N), results):
        print("class={0} ; value={1}".format(labels[class_idx], prob[class_idx]))
239
240
def download_file(url: str, force: bool = False, filename: str = None):
    """Downloads a file.

    Args:
        url (str): File url.
        force (bool): Forces to download the file even if it exists.
        filename (str): Renames the file when set. Otherwise the last path
            component of ``url`` is used, relative to the working directory.

    Raises:
        RuntimeError: If for some reason download fails, including HTTP error
            responses. The original exception is chained as ``__cause__``.

    Returns:
        str: Path to the downloaded file.
    """
    try:
        if filename is None:  # extract filename from url when None
            filename = os.path.basename(urlparse(url).path)

        print("Downloading '{0}' from '{1}' ...".format(filename, url))
        if not os.path.exists(filename) or force is True:
            r = requests.get(url)
            # Fail on 4xx/5xx instead of saving the server's error page as the file.
            r.raise_for_status()
            with open(filename, 'wb') as f:
                f.write(r.content)
            print("Finished.")
        else:
            print("File already exists.")
    except Exception as err:
        # Unlike a bare except, this lets KeyboardInterrupt/SystemExit propagate.
        raise RuntimeError("Unable to download file.") from err

    return filename
Pavel Macenauer09daef82020-06-02 11:54:59 +0000272
273
def _extract_zip(archive_path: str, output_dir: str):
    """Extracts all members of the zip at ``archive_path`` into ``output_dir`` (cwd when empty)."""
    with ZipFile(archive_path, 'r') as zip_obj:
        # extractall(None) extracts into the current working directory.
        zip_obj.extractall(output_dir or None)


def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
    """Gets model and labels.

    Both files are looked up in ``model_dir``; an empty ``model_dir`` means the
    current working directory. If either file is missing and ``archive`` exists in
    ``model_dir``, it is extracted into ``model_dir``. Otherwise every URL in
    ``download_url`` is downloaded into ``model_dir``, and URLs ending in ``.zip``
    are extracted there.

    Args:
        model_dir(str): Folder in which model and label files can be found
        model (str): Name of the model file
        labels (str): Name of the labels file
        archive (str): Name of the archive file inside ``model_dir`` (optional - need to provide only labels and model)
        download_url(str or list): Archive url or urls if multiple files (optional - to provide only to download it)

    Raises:
        RuntimeError: If the model or labels file is still missing afterwards.

    Returns:
        tuple (str, str): Model and labels file paths (joined with ``model_dir``)
    """
    labels = os.path.join(model_dir, labels)
    model = os.path.join(model_dir, model)

    if os.path.exists(labels) and os.path.exists(model):
        print("Found model ({0}) and labels ({1}).".format(model, labels))
    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
        print("Found archive ({0}). Unzipping ...".format(archive))
        # Open the archive where it was found, and extract into model_dir
        # because that is where the final existence check looks.
        _extract_zip(os.path.join(model_dir, archive), model_dir)
    elif download_url is not None:
        print("Model, labels or archive not found. Downloading ...")
        if isinstance(download_url, str):
            download_url = [download_url]
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        try:
            for dl in download_url:
                # Keep the URL's file name, but save it inside model_dir instead of the cwd.
                target = os.path.join(model_dir, os.path.basename(urlparse(dl).path))
                downloaded = download_file(dl, filename=target)
                if dl.lower().endswith(".zip"):
                    _extract_zip(downloaded, model_dir)
        except RuntimeError:
            print("Unable to download file ({}).".format(download_url))

    if not os.path.exists(labels) or not os.path.exists(model):
        raise RuntimeError("Unable to provide model and labels.")

    return model, labels
311
312
def list_images(folder: str = None, formats: list = ('.jpg', '.jpeg')):
    """Collects the files in a folder whose lowercased names end with one of ``formats``.

    Args:
        folder (str): Folder to search; the current working directory when empty or None.
        formats (list): Accepted file name endings. They are compared against the
            lowercased file name, so they should be lowercase themselves.

    Returns:
        list: Matching entries in ``os.listdir`` order, joined with ``folder`` when
        one was given. Empty (after printing a message) if ``folder`` does not exist.
    """
    if folder and not os.path.exists(folder):
        print("Folder '{}' does not exist.".format(folder))
        return []

    search_dir = folder if folder else os.getcwd()
    endings = tuple(formats)
    return [os.path.join(folder, name) if folder else name
            for name in os.listdir(search_dir)
            if name.lower().endswith(endings)]
335
336
def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
    """Gets the images to run inference on.

    Lists the ``.jpg``/``.jpeg`` files in ``image_dir`` (the current working
    directory when empty or None). If none are found and ``image_url`` is set,
    that single image is downloaded with ``download_file`` and used instead.

    Args:
        image_dir (str): Directory to search for images.
        image_url (str): Fallback image url; None disables the download.

    Raises:
        RuntimeError: If no images were found and none could be downloaded.

    Returns:
        list: Paths of the image files to use.
    """
    images = list_images(image_dir)
    if not images and image_url is not None:
        print("No images found. Downloading ...")
        try:
            images = [download_file(image_url)]
        except RuntimeError:
            print("Unable to download file ({0}).".format(image_url))

    if not images:
        raise RuntimeError("Unable to provide images.")

    return images
358 return images