blob: 5ef30f23315db3a4aa8ab7400d479f7d35f094dd [file] [log] [blame]
Pavel Macenauerd0fedae2020-04-15 14:52:57 +00001# Copyright 2020 NXP
2# SPDX-License-Identifier: MIT
3
4from urllib.parse import urlparse
5import os
6from PIL import Image
7import pyarmnn as ann
8import numpy as np
9import requests
10import argparse
11import warnings
12
13
def parse_command_line(desc: str = ""):
    """Build the example scripts' command-line interface and parse ``sys.argv``.

    Args:
        desc (str): Description shown in the ``--help`` output.

    Returns:
        argparse.Namespace: Parsed arguments. ``verbose`` is True when
        ``-v``/``--verbose`` was given, False otherwise.
    """
    cli = argparse.ArgumentParser(description=desc)
    cli.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Increase output verbosity",
    )
    # argparse reads sys.argv itself and exits on unrecognised arguments.
    parsed = cli.parse_args()
    return parsed
27
28
def __create_network(model_file: str, backends: list, parser=None):
    """Creates a network from a model file and loads it into a new runtime.

    Args:
        model_file (str): Path of the model file.
        backends (list): Names of the backends to use when running inference,
            in order of preference.
        parser: Parser instance (pyarmnn.ITfLiteParser/pyarmnn.IOnnxParser...).
            When None, one is chosen from the model file extension
            ('.onnx' or '.tflite', compared case-insensitively).

    Returns:
        int: Network ID.
        IParser: Parser instance that built the network.
        IRuntime: Runtime object instance the network is loaded into.

    Raises:
        ValueError: If ``parser`` is None and the model extension is not supported.
    """
    # Verbosity is taken from the script's own command line (-v/--verbose).
    args = parse_command_line()
    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)

    if parser is None:
        # Try to determine which parser to create from the model extension.
        ext = os.path.splitext(model_file)[1].lower()
        if ext == ".onnx":
            parser = ann.IOnnxParser()
        elif ext == ".tflite":
            parser = ann.ITfLiteParser()
        else:
            # An explicit exception (unlike `assert`) still fires under `python -O`.
            raise ValueError("Cannot infer a parser for '{0}': expected a .onnx or .tflite "
                             "file, or pass a parser explicitly".format(model_file))

    network = parser.CreateNetworkFromBinaryFile(model_file)

    preferred_backends = [ann.BackendId(b) for b in backends]
    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    # Optimizer messages and load warnings are only surfaced in verbose mode.
    if args.verbose:
        for m in messages:
            warnings.warn(m)

    net_id, w = runtime.LoadNetwork(opt_network)
    if args.verbose and w:
        warnings.warn(w)

    return net_id, parser, runtime
73
74
def create_tflite_network(model_file: str, backends: list = None):
    """Creates a network from a tflite model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference, in
            order of preference. Defaults to ['CpuAcc', 'CpuRef'].

    Returns:
        int: Network ID.
        int: Graph ID (index of the last subgraph in the model).
        ITFliteParser: TF Lite parser instance.
        IRuntime: Runtime object instance.
    """
    if backends is None:
        # Built per call instead of sharing one mutable default list between calls.
        backends = ['CpuAcc', 'CpuRef']
    net_id, parser, runtime = __create_network(model_file, backends, ann.ITfLiteParser())
    graph_id = parser.GetSubgraphCount() - 1

    return net_id, graph_id, parser, runtime
92
93
def create_onnx_network(model_file: str, backends: list = None):
    """Creates a network from an onnx model file.

    Args:
        model_file (str): Path of the model file.
        backends (list): List of backends to use when running inference, in
            order of preference. Defaults to ['CpuAcc', 'CpuRef'].

    Returns:
        int: Network ID.
        IOnnxParser: ONNX parser instance.
        IRuntime: Runtime object instance.
    """
    if backends is None:
        # Built per call instead of sharing one mutable default list between calls.
        backends = ['CpuAcc', 'CpuRef']
    return __create_network(model_file, backends, ann.IOnnxParser())
107
108
def preprocess_default(img: Image.Image, width: int, height: int, data_type, scale: float,
                       mean: list, stddev: list):
    """Default preprocessing image function.

    Converts the image to RGB, resizes it with bilinear filtering and then
    normalises every channel as ``((pixel / scale) - mean) / stddev``.

    Args:
        img (PIL.Image.Image): PIL image instance.
        width (int): Width to resize to.
        height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset.
        stddev (list): RGB standard deviation.

    Returns:
        np.array: 1-D array of length width * height * 3 in interleaved RGB
        order (R, G, B, R, G, B, ...), cast to ``data_type``.
    """
    # Convert before resizing: Pillow ignores the requested filter for palette
    # ('P') and bilevel ('1') images and always resizes them nearest-neighbour.
    img = img.convert('RGB')
    img = img.resize((width, height), Image.BILINEAR)
    pixels = np.array(img).reshape(-1, 3)  # one row per pixel: [R, G, B]
    pixels = ((pixels / scale) - mean) / stddev  # mean/stddev broadcast per channel
    return pixels.flatten().astype(data_type)
132
133
def load_images(image_files: list, input_width: int, input_height: int, data_type=np.uint8,
                scale: float = 1., mean: list = None, stddev: list = None,
                preprocess_fn=preprocess_default):
    """Loads images, resizes and performs any additional preprocessing to run inference.

    Args:
        image_files (list): Paths of the image files to load.
        input_width (int): Width to resize to.
        input_height (int): Height to resize to.
        data_type: Data Type to cast the image to.
        scale (float): Scaling value.
        mean (list): RGB mean offset. Defaults to [0., 0., 0.].
        stddev (list): RGB standard deviation. Defaults to [1., 1., 1.].
        preprocess_fn: Preprocessing function, called as
            ``preprocess_fn(img, input_width, input_height, data_type, scale, mean, stddev)``.

    Returns:
        list: One preprocessed image per entry of ``image_files``, in the same
        order (np.array each when the default ``preprocess_fn`` is used).
    """
    # Fresh lists per call instead of shared mutable default arguments.
    if mean is None:
        mean = [0., 0., 0.]
    if stddev is None:
        stddev = [1., 1., 1.]

    images = []
    for path in image_files:
        # Image.open is lazy and keeps the file open; the context manager closes
        # it once the pixels have been read during preprocessing.
        with Image.open(path) as img:
            images.append(preprocess_fn(img, input_width, input_height, data_type,
                                        scale, mean, stddev))
    return images
158
159
def load_labels(label_file: str):
    """Loads a labels file containing a label per line.

    Trailing whitespace (including the newline) is stripped from every line.
    Blank lines are kept as empty strings, so list index i corresponds to
    line i of the file.

    Args:
        label_file (str): Labels file path.

    Returns:
        list: List of labels read from the file, in file order.
    """
    with open(label_file, 'r') as f:
        return [line.rstrip() for line in f]
173
174
def print_top_n(N: int, results: list, labels: list, prob: list):
    """Print the first N entries of a ranked list of predictions.

    Every printed line has the form ``class=<label> ; value=<probability>``.

    Args:
        N (int): Maximum number of lines to print. Fewer are printed when
            ``results`` is shorter, and none when N <= 0.
        results (list): Class indices ordered from most to least likely.
        labels (list): Label for every class, indexed by class index.
        prob (list): Probability for every class, indexed by class index.

    Returns:
        None

    Raises:
        AssertionError: If ``results`` is empty or ``results``, ``labels`` and
            ``prob`` differ in length (only while assertions are enabled).
    """
    count = len(results)
    # All three sequences must describe the same set of classes.
    assert count >= 1 and count == len(labels) == len(prob)
    shown = min(count, N)
    for rank in range(shown):
        cls = results[rank]
        print("class={0} ; value={1}".format(labels[cls], prob[cls]))
190
191
def download_file(url: str, force: bool = False, filename: str = None, dest: str = "tmp"):
    """Downloads a file, skipping the download when it is already present.

    Args:
        url (str): File url.
        force (bool): Forces to download the file even if it exists.
        filename (str): Renames the file when set; otherwise the last path
            component of ``url`` is used.
        dest (str): Directory to save the file into (created when missing).
            When None, the file is saved relative to the current directory.

    Returns:
        str: Path to the downloaded (or already existing) file.

    Raises:
        ValueError: If no filename is given and none can be derived from ``url``.
        requests.HTTPError: If the server answers with an error status.
    """
    if filename is None:  # extract filename from url when None
        filename = os.path.basename(urlparse(url).path)
    if not filename:
        # e.g. "https://host/dir/" has no final path component; without this the
        # destination directory itself would be treated as the target file.
        raise ValueError("Could not derive a filename from url '{0}'".format(url))

    if dest is not None:
        # exist_ok avoids a race between an existence check and the creation.
        os.makedirs(dest, exist_ok=True)
        filename = os.path.join(dest, filename)

    print("Downloading '{0}' from '{1}' ...".format(filename, url))
    if force or not os.path.exists(filename):
        r = requests.get(url)
        # Fail loudly instead of saving an HTTP error page as the file, which
        # later runs would then report as "already exists" and never retry.
        r.raise_for_status()
        with open(filename, 'wb') as f:
            f.write(r.content)
        print("Finished.")
    else:
        print("File already exists.")

    return filename
221 return filename