blob: b3f79ea63cea544399cc938654f9e8ed3967b0b1 [file] [log] [blame]
Henri Woodcock3b38eed2021-05-19 13:41:44 +01001import argparse
2from pathlib import Path
3from typing import Union
4
5import tflite_runtime.interpreter as tflite
6from PIL import Image
7import numpy as np
8
9
def check_args(args: argparse.Namespace):
    """Check the values used in the command-line have acceptable values

    args:
        - args: argparse.Namespace

    returns:
        - None

    raises:
        - FileNotFoundError: if passed files do not exist.
        - IOError: if files are of incorrect format.
        - ValueError: if unsupported preferred backends are requested.
    """
    input_image_p = args.input_image
    if input_image_p.suffix not in (".png", ".jpg", ".jpeg"):
        raise IOError(
            "--input_image option should point to an image file of the "
            "format .jpg, .jpeg, .png"
        )
    if not input_image_p.exists():
        # single formatted message: FileNotFoundError("a", "b") would make
        # str(exc) print an args tuple instead of a readable message
        raise FileNotFoundError(f"Cannot find {input_image_p.name}")
    model_p = args.model_file
    if model_p.suffix != ".tflite":
        raise IOError("--model_file should point to a tflite file.")
    if not model_p.exists():
        raise FileNotFoundError(f"Cannot find {model_p.name}")
    label_mapping_p = args.label_file
    if label_mapping_p.suffix != ".txt":
        raise IOError("--label_file expects a .txt file.")
    if not label_mapping_p.exists():
        raise FileNotFoundError(f"Cannot find {label_mapping_p.name}")

    # check all args given in preferred backends make sense
    supported_backends = ("GpuAcc", "CpuAcc", "CpuRef")
    if not all(backend in supported_backends for backend in args.preferred_backends):
        raise ValueError("Incorrect backends given. Please choose from "
                         "'GpuAcc', 'CpuAcc', 'CpuRef'.")

    return None
49
50
def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
    """Load an image and put it into the format the tflite model expects.

    args:
        - image_path: pathlib.Path
        - model_input_dims: tuple (or array-like). (height,width)
        - grayscale: bool -> convert the image to greyscale when True

    returns:
        - image: np.array with a leading batch axis
    """
    target_height, target_width = model_input_dims
    # PIL's resize takes (width, height), the reverse of the model dims
    resized = Image.open(image_path).resize((target_width, target_height))
    if grayscale:
        # "LA" mode: luminance plus alpha
        resized = resized.convert("LA")
    # prepend the batch dimension expected by the interpreter
    return np.expand_dims(resized, axis=0)
71
72
def load_delegate(delegate_path: Path, backends: list):
    """Load the armnn delegate.

    args:
        - delegate_path: pathlib.Path -> location of you libarmnnDelegate.so
        - backends: list -> list of backends you want to use in string format

    returns:
        - armnn_delegate: tflite.delegate
    """
    # the delegate expects a comma separated backend string, e.g. "GpuAcc,CpuAcc"
    delegate_options = {
        "backends": ",".join(backends),
        "logging-severity": "info",
    }
    return tflite.load_delegate(library=delegate_path, options=delegate_options)
92
93
def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
    """Load a tflite model for use with the armnn delegate.

    args:
        - model_path: pathlib.Path
        - armnn_delegate: tflite.TfLiteDelegate

    returns:
        - interpreter: tflite.Interpreter, with tensors already allocated
    """
    interpreter = tflite.Interpreter(
        experimental_delegates=[armnn_delegate],
        model_path=model_path.as_posix(),
    )
    # tensors must be allocated before any inference can run
    interpreter.allocate_tensors()
    return interpreter
110
111
def run_inference(interpreter, input_image):
    """Run inference on a processed input image and return the raw output
    tensor from the model.

    args:
        - interpreter: tflite_runtime.interpreter.Interpreter
        - input_image: np.array

    returns:
        - output_data: np.array
    """
    # locate the tensors the model reads from and writes to
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    # feed the image through the network
    interpreter.set_tensor(input_index, input_image)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)
132
133
def create_mapping(label_mapping_p):
    """Creates a Python dictionary mapping an index to a label.

    label_mapping[idx] = label

    args:
        - label_mapping_p: pathlib.Path

    returns:
        - label_mapping: dict mapping 0-based line number to label text
    """
    label_mapping = {}
    with open(label_mapping_p) as label_mapping_raw:
        for idx, line in enumerate(label_mapping_raw):
            # strip the trailing newline that file iteration keeps, so the
            # label prints cleanly downstream
            label_mapping[idx] = line.rstrip("\n")
    return label_mapping
153
154
def process_output(output_data, label_mapping):
    """Turn the output tensor into a label from the labelmapping file, using
    the index of the maximum value in the output array.

    args:
        - output_data: np.array
        - label_mapping: dict

    returns:
        - str: labelmapping for max index.
    """
    best_idx = int(np.argmax(output_data[0]))
    return label_mapping[best_idx]
169
170
def main(args):
    """Run the inference for options passed in the command line.

    args:
        - args: argparse.Namespace

    returns:
        - None
    """
    # validate the command-line options before touching any files
    check_args(args)
    # set up the delegate and the model it will accelerate
    delegate = load_delegate(args.delegate_path, args.preferred_backends)
    interpreter = load_tf_model(args.model_file, delegate)
    # the model dictates the image size: input shape is (batch, height, width, ...)
    model_shape = interpreter.get_input_details()[0]["shape"]
    image = load_image(args.input_image, (model_shape[1], model_shape[2]), False)
    # run the image through the network and translate the result to a label
    label_mapping = create_mapping(args.label_file)
    prediction = process_output(run_inference(interpreter, image), label_mapping)

    print("Prediction: ", prediction)

    return None
200
201
if __name__ == "__main__":
    # command-line entry point: classify one image with the ArmNN delegate
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # required file locations
    parser.add_argument("--input_image", type=Path, required=True,
                        help="File path of image file")
    parser.add_argument("--model_file", type=Path, required=True,
                        help="File path of the model tflite file")
    parser.add_argument("--label_file", type=Path, required=True,
                        help="File path of model labelmapping file")
    parser.add_argument("--delegate_path", type=Path, required=True,
                        help="File path of ArmNN delegate file")
    # optional execution preferences
    parser.add_argument("--preferred_backends", type=str, nargs="+",
                        required=False, default=["CpuAcc", "CpuRef"],
                        help="list of backends in order of preference")

    main(parser.parse_args())