#
# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
import argparse
6from pathlib import Path
7from typing import Union
8
9import tflite_runtime.interpreter as tflite
10from PIL import Image
11import numpy as np
12
13
def check_args(args: argparse.Namespace) -> None:
    """Check the values used in the command-line have acceptable values

    args:
        - args: argparse.Namespace

    returns:
        - None

    raises:
        - FileNotFoundError: if passed files do not exist.
        - IOError: if files are of incorrect format.
        - ValueError: if an unsupported preferred backend is requested.
    """
    input_image_p = args.input_image
    if input_image_p.suffix not in (".png", ".jpg", ".jpeg"):
        raise IOError(
            "--input_image option should point to an image file of the "
            "format .jpg, .jpeg, .png"
        )
    if not input_image_p.exists():
        # f-string: the original passed two args to the exception, which
        # rendered the message as a tuple, e.g. "('Cannot find ', 'x.png')"
        raise FileNotFoundError(f"Cannot find {input_image_p.name}")
    model_p = args.model_file
    if model_p.suffix != ".tflite":
        raise IOError("--model_file should point to a tflite file.")
    if not model_p.exists():
        raise FileNotFoundError(f"Cannot find {model_p.name}")
    label_mapping_p = args.label_file
    if label_mapping_p.suffix != ".txt":
        raise IOError("--label_file expects a .txt file.")
    if not label_mapping_p.exists():
        raise FileNotFoundError(f"Cannot find {label_mapping_p.name}")

    # check all args given in preferred backends make sense
    supported_backends = {"GpuAcc", "CpuAcc", "CpuRef"}
    if not supported_backends.issuperset(args.preferred_backends):
        raise ValueError("Incorrect backends given. Please choose from "
                         "'GpuAcc', 'CpuAcc', 'CpuRef'.")

    return None
53
54
def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
    """Load an image from disk and shape it for the tensorflow lite model.

    args:
        - image_path: pathlib.Path
        - model_input_dims: tuple (or array-like). (height, width)
        - grayscale: bool -> convert the image to greyscale when True

    returns:
        - image: np.array with a leading batch dimension
    """
    target_height, target_width = model_input_dims
    # PIL's resize takes (width, height), the reverse of the model dims
    resized = Image.open(image_path).resize((target_width, target_height))
    if grayscale:
        # NOTE(review): "LA" produces luminance+alpha (2 channels); if the
        # model expects a single channel, "L" may be intended — confirm.
        resized = resized.convert("LA")
    # add the batch dimension the interpreter expects
    return np.expand_dims(resized, axis=0)
75
76
def load_delegate(delegate_path: Path, backends: list):
    """Load the Arm NN external delegate for TFLite.

    args:
        - delegate_path: pathlib.Path -> location of your libarmnnDelegate.so
        - backends: list -> backend names (strings) in order of preference

    returns:
        - armnn_delegate: tflite.delegate
    """
    # the delegate options expect the backends as one comma-separated string
    return tflite.load_delegate(
        library=delegate_path,
        options={"backends": ",".join(backends), "logging-severity": "info"},
    )
96
97
def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
    """Create a tflite interpreter for a model, backed by the armnn delegate.

    args:
        - model_path: pathlib.Path
        - armnn_delegate: tflite.TfLiteDelegate

    returns:
        - interpreter: tflite.Interpreter with tensors already allocated
    """
    tflite_interpreter = tflite.Interpreter(
        model_path=model_path.as_posix(),
        experimental_delegates=[armnn_delegate],
    )
    # allocate once here so callers can invoke immediately
    tflite_interpreter.allocate_tensors()
    return tflite_interpreter
114
115
def run_inference(interpreter, input_image):
    """Run a single inference on a processed input image and return the raw
    output tensor.

    args:
        - interpreter: tflite_runtime.interpreter.Interpreter
        - input_image: np.array

    returns:
        - output_data: np.array
    """
    # first input / first output tensor indices of the model
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    interpreter.set_tensor(input_index, input_image)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)
136
137
def create_mapping(label_mapping_p: Path) -> dict:
    """Creates a Python dictionary mapping an index to a label.

    label_mapping[idx] = label

    args:
        - label_mapping_p: pathlib.Path

    returns:
        - label_mapping: dict mapping int line index -> label string
    """
    label_mapping = {}
    with open(label_mapping_p) as label_mapping_raw:
        for idx, line in enumerate(label_mapping_raw):
            # strip the trailing newline so labels print cleanly; the
            # original stored "label\n"
            label_mapping[idx] = line.rstrip("\n")

    return label_mapping
157
158
def process_output(output_data, label_mapping):
    """Map the highest-scoring entry of the output tensor to its label from
    the labelmapping file.

    args:
        - output_data: np.array
        - label_mapping: dict

    returns:
        - str: labelmapping for max index.
    """
    # index of the maximum value in the first (batch) row
    best_idx = np.argmax(output_data[0])
    return label_mapping[best_idx]
173
174
def main(args):
    """Run the inference for options passed in the command line.

    args:
        - args: argparse.Namespace

    returns:
        - None
    """
    # validate the command-line options before doing any work
    check_args(args)
    # delegate + model setup
    armnn_delegate = load_delegate(args.delegate_path, args.preferred_backends)
    interpreter = load_tf_model(args.model_file, armnn_delegate)
    # model input shape is reported as (batch, height, width, channels);
    # pass (height, width) for image resizing
    model_shape = interpreter.get_input_details()[0]["shape"]
    input_image = load_image(args.input_image, (model_shape[1], model_shape[2]), False)
    # classify and translate the winning index into a human-readable label
    labelmapping = create_mapping(args.label_file)
    output_tensor = run_inference(interpreter, input_image)
    output_prediction = process_output(output_tensor, labelmapping)

    print("Prediction: ", output_prediction)

    return None
204
205
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    arg_parser.add_argument(
        "--input_image", help="File path of image file", type=Path, required=True
    )
    arg_parser.add_argument(
        "--model_file",
        help="File path of the model tflite file",
        type=Path,
        required=True,
    )
    arg_parser.add_argument(
        "--label_file",
        help="File path of model labelmapping file",
        type=Path,
        required=True,
    )
    arg_parser.add_argument(
        "--delegate_path",
        help="File path of ArmNN delegate file",
        type=Path,
        required=True,
    )
    arg_parser.add_argument(
        "--preferred_backends",
        help="list of backends in order of preference",
        type=str,
        nargs="+",
        required=False,
        default=["CpuAcc", "CpuRef"],
    )
    main(arg_parser.parse_args())