Henri Woodcock | 3b38eed | 2021-05-19 13:41:44 +0100 | [diff] [blame] | 1 | import argparse |
| 2 | from pathlib import Path |
| 3 | from typing import Union |
| 4 | |
| 5 | import tflite_runtime.interpreter as tflite |
| 6 | from PIL import Image |
| 7 | import numpy as np |
| 8 | |
| 9 | |
def check_args(args: argparse.Namespace):
    """Check the values used in the command-line have acceptable values

    args:
      - args: argparse.Namespace -> must provide input_image, model_file,
        label_file (pathlib.Path) and preferred_backends (list of str)

    returns:
      - None

    raises:
      - FileNotFoundError: if passed files do not exist.
      - IOError: if a file does not have the expected extension.
      - ValueError: if a backend outside GpuAcc/CpuAcc/CpuRef is requested.
    """
    input_image_p = args.input_image
    # compare suffixes case-insensitively so that e.g. "photo.JPG" is accepted
    if input_image_p.suffix.lower() not in (".png", ".jpg", ".jpeg"):
        raise IOError(
            "--input_image option should point to an image file of the "
            "format .jpg, .jpeg, .png"
        )
    if not input_image_p.exists():
        # build ONE message string: two positional arguments to an OSError
        # subclass are interpreted as (errno, strerror) and garble the text
        raise FileNotFoundError(f"Cannot find {input_image_p}")

    model_p = args.model_file
    if model_p.suffix.lower() != ".tflite":
        raise IOError("--model_file should point to a tflite file.")
    if not model_p.exists():
        raise FileNotFoundError(f"Cannot find {model_p}")

    label_mapping_p = args.label_file
    if label_mapping_p.suffix.lower() != ".txt":
        raise IOError("--label_file expects a .txt file.")
    if not label_mapping_p.exists():
        raise FileNotFoundError(f"Cannot find {label_mapping_p}")

    # check all args given in preferred backends make sense
    supported_backends = ("GpuAcc", "CpuAcc", "CpuRef")
    if not all(backend in supported_backends for backend in args.preferred_backends):
        raise ValueError(
            "Incorrect backends given. Please choose from "
            "'GpuAcc', 'CpuAcc', 'CpuRef'."
        )

    return None
| 49 | |
| 50 | |
def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
    """load an image and put into correct format for the tensorflow lite model

    args:
      - image_path: pathlib.Path
      - model_input_dims: tuple (or array-like). (height,width)
      - grayscale: bool -> when True the resized image is converted to PIL
        mode "LA" before being turned into an array

    returns:
      - image: np.array -> the image data with an extra leading axis of size 1
    """
    height, width = model_input_dims
    # open inside a context manager so the file handle is released as soon as
    # the resized copy exists; resize() returns a new image object
    with Image.open(image_path) as source_image:
        # PIL takes the target size as (width, height)
        image = source_image.resize((width, height))

    # convert to greyscale if expected
    # NOTE(review): mode "LA" is luminance + alpha (two channels) - confirm
    # this matches the channel count the model expects.
    if grayscale:
        image = image.convert("LA")

    # add the batch axis in front of the image axes
    image = np.expand_dims(image, axis=0)

    return image
| 71 | |
| 72 | |
def load_delegate(delegate_path: Path, backends: list):
    """load the armnn delegate.

    args:
      - delegate_path: pathlib.Path -> location of your libarmnnDelegate.so
      - backends: list -> list of backends you want to use in string format

    returns:
      - armnn_delegate: tflite.delegate
    """
    # create a comma-separated string, keeping the given order of preference
    backend_string = ",".join(backends)
    # load delegate; the library location is passed as a plain string (the
    # same way load_tf_model passes the model path) rather than as a Path
    armnn_delegate = tflite.load_delegate(
        library=str(delegate_path),
        options={"backends": backend_string, "logging-severity": "info"},
    )

    return armnn_delegate
| 92 | |
| 93 | |
def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
    """Build a tflite interpreter for a model file, attaching the armnn delegate.

    args:
      - model_path: pathlib.Path -> location of the .tflite model
      - armnn_delegate: tflite.Delegate -> delegate handed to the interpreter

    returns:
      - interpreter: tflite.Interpreter -> returned after allocate_tensors()
        has been called on it
    """
    model_location = model_path.as_posix()
    delegates = [armnn_delegate]

    interpreter = tflite.Interpreter(
        model_path=model_location,
        experimental_delegates=delegates,
    )
    # tensors are allocated here so the caller gets a ready-to-use interpreter
    interpreter.allocate_tensors()

    return interpreter
| 110 | |
| 111 | |
def run_inference(interpreter, input_image):
    """Feed a processed input image to the interpreter and return the content
    of its first output tensor.

    args:
      - interpreter: tflite_runtime.interpreter.Interpreter
      - input_image: np.array

    returns:
      - output_data: np.array
    """
    # Look up the tensor descriptions once, up front.
    model_inputs = interpreter.get_input_details()
    model_outputs = interpreter.get_output_details()

    # Only the first input and the first output tensor are used.
    interpreter.set_tensor(model_inputs[0]["index"], input_image)
    interpreter.invoke()

    return interpreter.get_tensor(model_outputs[0]["index"])
| 132 | |
| 133 | |
def create_mapping(label_mapping_p):
    """Creates a Python dictionary mapping an index to a label.

    label_mapping[idx] = label

    The index is the zero-based line number in the file. Surrounding
    whitespace, including the trailing newline, is removed from every label.

    args:
      - label_mapping_p: pathlib.Path

    returns:
      - label_mapping: dict
    """
    with open(label_mapping_p) as label_mapping_raw:
        # every line is kept (blank ones too) so that the line number stays
        # aligned with the index it is looked up by
        label_mapping = {
            idx: line.strip() for idx, line in enumerate(label_mapping_raw)
        }

    return label_mapping
| 153 | |
| 154 | |
def process_output(output_data, label_mapping):
    """Turn the output tensor into a label from the labelmapping file by
    taking the index of the maximum value in the first row of the output.

    args:
      - output_data: np.array
      - label_mapping: dict

    returns:
      - str: labelmapping for max index.
    """
    first_row = output_data[0]
    best_index = np.argmax(first_row)

    return label_mapping[best_index]
| 169 | |
| 170 | |
def main(args):
    """Run one image classification using the options passed on the command
    line and print the predicted label.

    args:
      - args: argparse.Namespace

    returns:
      - None
    """
    # reject unusable command-line values before loading anything
    check_args(args)

    # set up the armnn delegate and the interpreter that uses it
    delegate = load_delegate(args.delegate_path, args.preferred_backends)
    interpreter = load_tf_model(args.model_file, delegate)

    # entries 1 and 2 of the first input's shape are used as height and width
    model_shape = interpreter.get_input_details()[0]["shape"]
    image_dims = (model_shape[1], model_shape[2])

    # the image is loaded without the greyscale conversion
    input_image = load_image(args.input_image, image_dims, False)

    # index -> label lookup table
    labels = create_mapping(args.label_file)

    raw_output = run_inference(interpreter, input_image)
    prediction = process_output(raw_output, labels)

    print("Prediction: ", prediction)
| 200 | |
| 201 | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The four file options are all mandatory and parsed into pathlib.Path.
    for option, description in (
        ("--input_image", "File path of image file"),
        ("--model_file", "File path of the model tflite file"),
        ("--label_file", "File path of model labelmapping file"),
        ("--delegate_path", "File path of ArmNN delegate file"),
    ):
        parser.add_argument(option, help=description, type=Path, required=True)

    # Optional: one or more backend names, most preferred first.
    parser.add_argument(
        "--preferred_backends",
        help="list of backends in order of preference",
        type=str,
        nargs="+",
        required=False,
        default=["CpuAcc", "CpuRef"],
    )

    args = parser.parse_args()

    main(args)