blob: e5425dde52a2c91a896ec45d49643592813819c8 [file] [log] [blame]
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00001# Copyright 2020 NXP
2# SPDX-License-Identifier: MIT
3
4from urllib.parse import urlparse
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00005from PIL import Image
Pavel Macenauer09daef82020-06-02 11:54:59 +00006from zipfile import ZipFile
7import os
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00008import pyarmnn as ann
9import numpy as np
10import requests
11import argparse
12import warnings
13
# Image downloaded by get_images() when no local images are found.
DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
15
16
def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
    """Run the loaded network on every image and print its top-5 classes.

    The output tensors are allocated once and reused for every image; a new
    input tensor is built for each image.

    Args:
        runtime: Arm NN runtime the network was loaded into.
        net_id: ID of the loaded network, as returned by ``IRuntime.LoadNetwork``.
        images: Preprocessed images, one network input each.
        labels: Class labels, indexed by class id.
        input_binding_info: Binding information of the network input.
        output_binding_info: Binding information of the network output.

    Returns:
        None
    """
    outputs = ann.make_output_tensors([output_binding_info])
    for image_index, image in enumerate(images):
        inputs = ann.make_input_tensors([input_binding_info], [image])

        print("Running inference({0}) ...".format(image_index))
        runtime.EnqueueWorkload(net_id, inputs, outputs)

        # Only the first (and only bound) output tensor is inspected.
        scores = ann.workload_tensors_to_ndarray(outputs)[0]
        # argsort is ascending; reverse it so the highest score comes first.
        ranking = np.argsort(scores)[::-1]
        print_top_n(5, ranking, labels, scores)
44
45
def unzip_file(filename: str):
    """Extract every member of a zip archive into the current working directory.

    Args:
        filename (str): Path of the zip archive to extract.

    Returns:
        None
    """
    archive = ZipFile(filename, 'r')
    with archive:
        # extractall() without a path argument writes into the cwd.
        archive.extractall()
57
Pavel Macenauerd0fedae2020-04-15 14:52:57 +000058
def parse_command_line(desc: str = ""):
    """Parse the command-line options shared by the example scripts.

    Args:
        desc (str): Description shown in the ``--help`` output.

    Returns:
        Namespace: Parsed arguments with ``verbose`` (bool), ``data_dir`` (str)
        and ``model_dir`` (str) attributes.
    """
    arg_parser = argparse.ArgumentParser(description=desc)
    option_specs = (
        (("-v", "--verbose"),
         dict(help="Increase output verbosity", action="store_true")),
        (("-d", "--data-dir"),
         dict(help="Data directory which contains all the images.", action="store", default="")),
        (("-m", "--model-dir"),
         dict(help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).",
              action="store", default="")),
    )
    for flags, settings in option_specs:
        arg_parser.add_argument(*flags, **settings)
    return arg_parser.parse_args()
77
78
def __create_network(model_file: str, backends: list, parser=None, verbose: bool = None):
    """Creates, optimizes and loads a network from a model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): Names of the backends to use when running inference, in
            order of preference (each is wrapped in ``pyarmnn.BackendId``).
        parser: Parser instance (pyarmnn.ITfLiteParser/pyarmnn.IOnnxParser...). When
            None, it is chosen from the model file extension ('.onnx' or '.tflite').
        verbose (bool): Emit optimizer and network-loading messages as warnings. When
            None, the ``-v/--verbose`` flag of the script's command line is used.

    Raises:
        ValueError: If ``parser`` is None and the model file extension is not
            '.onnx' or '.tflite'.

    Returns:
        int: Network ID
        IParser: The parser instance used to create the network
        IRuntime: Runtime object instance
    """
    if verbose is None:
        # Backward-compatible fallback: read the flag from the command line.
        verbose = parse_command_line().verbose

    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    if parser is None:
        # Try to determine which parser to create from the model extension.
        ext = os.path.splitext(model_file)[1].lower()
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
        else:
            # Explicit error instead of `assert`, which `python -O` strips.
            raise ValueError("Unable to determine parser for model file '{0}'.".format(model_file))

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = [ann.BackendId(b) for b in backends]
    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    if verbose:
        for m in messages:
            warnings.warn(m)

    net_id, load_message = runtime.LoadNetwork(opt_network)
    if verbose and load_message:
        warnings.warn(load_message)

    return net_id, parser, runtime
122
123
def create_tflite_network(model_file: str, backends: list = None):
    """Creates a network from a tflite model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.
            Defaults to ['CpuAcc', 'CpuRef'].

    Returns:
        int: Network ID.
        int: Graph ID (``GetSubgraphCount() - 1``, i.e. the last subgraph).
        ITFliteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if backends is None:
        backends = ['CpuAcc', 'CpuRef']

    net_id, parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
    graph_id = parser.GetSubgraphCount() - 1

    return net_id, graph_id, parser, runtime
141
142
def create_onnx_network(model_file: str, backends: list = None):
    """Creates a network from an onnx model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference.
            Defaults to ['CpuAcc', 'CpuRef'].

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if backends is None:
        backends = ['CpuAcc', 'CpuRef']

    return __create_network(model_file, backends, ann.IOnnxParser())
156
157
def preprocess_default(img: Image.Image, width: int, height: int, data_type, scale: float, mean: list,
                       stddev: list):
    """Default preprocessing image function.

    Converts the image to RGB, resizes it with bilinear filtering and normalizes
    every channel as ``((pixel / scale) - mean) / stddev``.

    Args:
        img (PIL.Image): PIL.Image object instance.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.

    Returns:
        np.array: Flat array of ``height * width * 3`` values in interleaved RGB
        order ([R, G, B, R, G, B, ...]), cast to ``data_type``.
    """
    # Convert before resizing: Pillow falls back to NEAREST resampling for '1' and
    # 'P' (palette) images, which would silently ignore the requested BILINEAR filter.
    img = img.convert('RGB')
    img = img.resize((width, height), Image.BILINEAR)
    pixels = np.array(img)
    pixels = np.reshape(pixels, (-1, 3))  # one row per pixel: [R, G, B]
    pixels = ((pixels / scale) - mean) / stddev  # mean/stddev broadcast per channel
    return pixels.flatten().astype(data_type)
181
182
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = None, stddev: list = None,
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): Paths of the image files to load.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset. Defaults to [0., 0., 0.].
        stddev (list): RGB standard deviation. Defaults to [1., 1., 1.].
        preprocess_fn: Preprocessing function, called as ``preprocess_fn(img, input_width,
            input_height, data_type, scale, mean, stddev)``. It must return data that does
            not depend on the (closed) image file, e.g. a numpy array.

    Returns:
        list: One preprocessed image (the return value of ``preprocess_fn``) per file,
        in the order of ``image_files``.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    if mean is None:
        mean = [0., 0., 0.]
    if stddev is None:
        stddev = [1., 1., 1.]

    images = []
    for image_file in image_files:
        # Close the file deterministically; Pillow keeps it open for multi-frame
        # images even after the pixel data has been read.
        with Image.open(image_file) as img:
            images.append(preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev))
    return images
207
208
def load_labels(label_file: str):
    """Loads a labels file containing a label per line.

    Args:
        label_file (str): Labels file path (UTF-8 encoded).

    Returns:
        list: One label per line, with trailing whitespace (including the line
        ending) removed.
    """
    # Explicit encoding so labels decode the same way regardless of the platform locale.
    with open(label_file, 'r', encoding='utf-8') as f:
        return [line.rstrip() for line in f]
222
223
def print_top_n(N: int, results: list, labels: list, prob: list):
    """Prints TOP-N results

    Args:
        N (int): Result count to print (values <= 0 print nothing).
        results (list): Top prediction indices, best first.
        labels (list): A list of labels for every class.
        prob (list): A list of probabilities for every class.

    Raises:
        ValueError: If ``results`` is empty, or ``results``, ``labels`` and ``prob``
            differ in length.

    Returns:
        None
    """
    # Explicit check instead of `assert`, which `python -O` strips.
    if not (len(results) >= 1 and len(results) == len(labels) == len(prob)):
        raise ValueError("results, labels and prob must be non-empty and of equal length "
                         "(got {0}, {1} and {2}).".format(len(results), len(labels), len(prob)))
    # max(N, 0): a negative slice bound would otherwise drop items from the end.
    for class_id in results[:max(N, 0)]:
        print("class={0} ; value={1}".format(labels[class_id], prob[class_id]))
239
240
def download_file(url: str, force: bool = False, filename: str = None, timeout: float = 60.):
    """Downloads a file.

    Args:
        url (str): File url.
        force (bool): Forces to download the file even if it exists.
        filename (str): Renames the file when set. When None, the last component of
            the url path is used, relative to the current working directory.
        timeout (float): Seconds to wait for the server to connect or send data before
            giving up (passed to ``requests.get``).

    Raises:
        RuntimeError: If for some reason download fails, including HTTP error
            responses (4xx/5xx). The original error is chained as ``__cause__``.

    Returns:
        str: Path to the downloaded file.
    """
    try:
        if filename is None:  # extract filename from url when None
            filename = os.path.basename(urlparse(url).path)

        print("Downloading '{0}' from '{1}' ...".format(filename, url))
        if os.path.exists(filename) and not force:
            print("File already exists.")
        else:
            response = requests.get(url, timeout=timeout)
            # Without this, an error page (e.g. a 404) would be saved as the file.
            response.raise_for_status()
            with open(filename, 'wb') as f:
                f.write(response.content)
            print("Finished.")
    except Exception as err:  # not a bare `except:` so Ctrl-C / SystemExit propagate
        raise RuntimeError("Unable to download file.") from err

    return filename
Pavel Macenauer09daef82020-06-02 11:54:59 +0000272
273
def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
    """Gets model and labels.

    Looks for ``model`` and ``labels`` in ``model_dir``. If either is missing, the
    ``archive`` stored in ``model_dir`` is extracted into ``model_dir``. Failing that,
    every url in ``download_url`` is downloaded into ``model_dir``, and each url ending
    in '.zip' is extracted there.

    Args:
        model_dir(str): Folder in which model and label files can be found ('' for the
            current working directory)
        model (str): Name of the model file
        labels (str): Name of the labels file
        archive (str): Name of the archive file inside ``model_dir`` (optional - need to provide only labels and model)
        download_url(str or list): Archive url or urls if multiple files (optional - need to provide only to download it)

    Raises:
        RuntimeError: If the model or labels file is still missing afterwards.

    Returns:
        tuple (str, str): Model and label file paths, in that order
    """
    labels = os.path.join(model_dir, labels)
    model = os.path.join(model_dir, model)

    def _extract(zip_path):
        # Extract next to where model/labels are looked up; None means the cwd.
        with ZipFile(zip_path, 'r') as zip_obj:
            zip_obj.extractall(model_dir or None)

    if os.path.exists(labels) and os.path.exists(model):
        print("Found model ({0}) and labels ({1}).".format(model, labels))
    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
        print("Found archive ({0}). Unzipping ...".format(archive))
        # Unzip the archive that was actually found in model_dir, not a cwd-relative one.
        _extract(os.path.join(model_dir, archive))
    elif download_url is not None:
        print("Model, labels or archive not found. Downloading ...")
        urls = [download_url] if isinstance(download_url, str) else download_url
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        for url in urls:
            # Download into model_dir so the existence check below can find the files.
            target = os.path.join(model_dir, os.path.basename(urlparse(url).path))
            try:
                downloaded = download_file(url, filename=target)
            except RuntimeError:
                print("Unable to download file ({}).".format(url))
                break  # as before: stop at the first failed download
            if url.lower().endswith(".zip"):
                _extract(downloaded)

    if not os.path.exists(labels) or not os.path.exists(model):
        raise RuntimeError("Unable to provide model and labels.")

    return model, labels
311
312
def list_images(folder: str = None, formats: list = None):
    """Lists files of a certain format in a folder.

    Args:
        folder (str): Path to the folder to search; the current working directory
            when empty or None.
        formats (list): File extensions to accept, matched case-insensitively.
            Defaults to ['.jpg', '.jpeg'].

    Returns:
        list: Matching files sorted by name. Paths are joined with ``folder`` when
        it is given; otherwise they are bare file names.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if formats is None:
        formats = ['.jpg', '.jpeg']

    files = []
    if folder and not os.path.isdir(folder):
        print("Folder '{}' does not exist.".format(folder))
        return files

    search_dir = folder if folder else os.getcwd()
    # str.endswith accepts a tuple; lower-case both sides for a case-insensitive match.
    extensions = tuple(frmt.lower() for frmt in formats)
    # sorted(): os.listdir order is arbitrary; keep results (and inference order) stable.
    for name in sorted(os.listdir(search_dir)):
        if not name.lower().endswith(extensions):
            continue
        if not os.path.isfile(os.path.join(search_dir, name)):
            continue  # skip directories that merely have an image-like name
        files.append(os.path.join(folder, name) if folder else name)

    return files
335
336
def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
    """Gets the images to run inference on.

    Images are first looked up in ``image_dir`` via ``list_images``. If none are
    found, a single image is downloaded from ``image_url`` instead.

    Args:
        image_dir (str): Folder to search for images ('' or None for the current
            working directory).
        image_url (str): Url of an image to download when no image is found; None
            disables the download.

    Raises:
        RuntimeError: If no image could be found or downloaded.

    Returns:
        list: Image file paths.
    """
    images = list_images(image_dir)

    if not images and image_url is not None:
        print("No images found. Downloading ...")
        try:
            downloaded = download_file(image_url)
        except RuntimeError:
            print("Unable to download file ({0}).".format(image_url))
        else:
            images = [downloaded]

    if images:
        return images
    raise RuntimeError("Unable to provide images.")
358 return images