blob: 4ce8b8b84efd418f18d5f389723e1d0e265d0081 [file] [log] [blame]
#
# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
import argparse
from pathlib import Path
from typing import Union
import tflite_runtime.interpreter as tflite
from PIL import Image
import numpy as np
def check_args(args: argparse.Namespace):
    """Check that the values passed on the command line are acceptable.

    Files are checked in order (input image, model, label file). For each file
    the extension is checked first (case-insensitively), then its existence.
    Finally every requested backend must be one of the supported names.

    args:
        - args: argparse.Namespace with ``input_image``, ``model_file`` and
          ``label_file`` (pathlib.Path) and ``preferred_backends`` (list of str)
    returns:
        - None
    raises:
        - IOError: if a file does not have the expected extension.
        - FileNotFoundError: if a file does not exist.
        - ValueError: if an unsupported backend is requested.
    """
    # (path, accepted lowercase suffixes, message used when the suffix is wrong)
    file_checks = (
        (
            args.input_image,
            (".png", ".jpg", ".jpeg"),
            "--input_image option should point to an image file of the "
            "format .jpg, .jpeg, .png",
        ),
        (args.model_file, (".tflite",), "--model_file should point to a tflite file."),
        (args.label_file, (".txt",), "--label_file expects a .txt file."),
    )
    for path, suffixes, suffix_error in file_checks:
        # Compare case-insensitively so that e.g. "photo.JPG" is accepted.
        if path.suffix.lower() not in suffixes:
            raise IOError(suffix_error)
        if not path.exists():
            # Use one formatted message. Two positional arguments to an OSError
            # subclass are interpreted as (errno, strerror).
            raise FileNotFoundError(f"Cannot find {path.name}")

    # Check that all requested backends are supported.
    supported_backends = ("GpuAcc", "CpuAcc", "CpuRef")
    unsupported = [
        backend for backend in args.preferred_backends
        if backend not in supported_backends
    ]
    if unsupported:
        raise ValueError(
            "Incorrect backends given. Please choose from "
            f"'GpuAcc', 'CpuAcc', 'CpuRef'. Got: {unsupported}"
        )
    return None
def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
    """Load an image and put it into the layout expected by the TFLite model.

    args:
        - image_path: pathlib.Path
        - model_input_dims: tuple (or array-like). (height, width)
        - grayscale: bool -> if True produce a single-channel image, otherwise RGB
    returns:
        - image: np.array of shape (1, height, width, channels), where channels
          is 1 when ``grayscale`` is True and 3 otherwise
    """
    height, width = model_input_dims
    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(image_path) as source:
        # Normalise the colour mode so that the channel count does not depend
        # on the file (RGBA PNGs, palette images, greyscale JPEGs, ...).
        # "L" is luminance only; "LA" would add a second (alpha) channel.
        mode = "L" if grayscale else "RGB"
        resized = source.convert(mode).resize((width, height))
    image = np.asarray(resized)
    if grayscale:
        # Single-band images become 2-D arrays; add the trailing channel axis.
        image = np.expand_dims(image, axis=-1)
    # Add the leading batch dimension.
    return np.expand_dims(image, axis=0)
def load_delegate(delegate_path: Path, backends: list, logging_severity: str = "info"):
    """Load the Arm NN delegate.

    args:
        - delegate_path: pathlib.Path (or str) -> location of your libarmnnDelegate.so
        - backends: list -> backends to use, in order of preference, as strings
        - logging_severity: str -> value passed as the delegate's
          "logging-severity" option (default "info")
    returns:
        - armnn_delegate: tflite.Delegate
    """
    # The delegate takes the backends as a single comma-separated string.
    backend_string = ",".join(backends)
    armnn_delegate = tflite.load_delegate(
        # `library` is documented as a string; convert so that callers can
        # pass either a pathlib.Path or a str.
        library=str(delegate_path),
        options={"backends": backend_string, "logging-severity": logging_severity},
    )
    return armnn_delegate
def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
    """Build a TFLite interpreter for a model, backed by the Arm NN delegate.

    Tensors are allocated before returning, so the interpreter is ready for
    set_tensor/invoke.

    args:
        - model_path: pathlib.Path -> location of the .tflite model
        - armnn_delegate: tflite.Delegate -> delegate returned by load_delegate
    returns:
        - interpreter: tflite.Interpreter
    """
    model_location = model_path.as_posix()
    delegates = [armnn_delegate]
    tf_interpreter = tflite.Interpreter(
        model_path=model_location,
        experimental_delegates=delegates,
    )
    tf_interpreter.allocate_tensors()
    return tf_interpreter
def run_inference(interpreter, input_image):
    """Feed a preprocessed image to the interpreter and return its first output.

    args:
        - interpreter: tflite_runtime.interpreter.Interpreter
        - input_image: np.array -> written into the model's first input tensor
    returns:
        - output_data: np.array -> contents of the model's first output tensor
    """
    input_index = interpreter.get_input_details()[0]["index"]
    output_details = interpreter.get_output_details()
    # Copy the given image into the first input tensor and run the graph.
    interpreter.set_tensor(input_index, input_image)
    interpreter.invoke()
    # Read back the first output tensor.
    return interpreter.get_tensor(output_details[0]["index"])
def create_mapping(label_mapping_p):
    """Create a Python dictionary mapping a line index to a label.

    label_mapping[idx] = label, where idx is the zero-based line number in the
    label file and label is that line without its line terminator.

    args:
        - label_mapping_p: pathlib.Path (or str) -> text file, one label per line
    returns:
        - label_mapping: dict
    """
    with open(label_mapping_p, encoding="utf-8") as label_mapping_raw:
        # Strip only the line terminator, so printed predictions carry no
        # trailing newline while other whitespace inside a label is kept.
        return {
            idx: line.rstrip("\r\n")
            for idx, line in enumerate(label_mapping_raw)
        }
def process_output(output_data, label_mapping):
    """Turn the output tensor into a label from the label mapping.

    Only the first entry of the output (output_data[0]) is inspected. The label
    returned is the one stored under the index of its maximum value.

    args:
        - output_data: np.array
        - label_mapping: dict -> index to label
    returns:
        - str: label mapped to the index of the maximum score.
    """
    scores = output_data[0]
    best_index = np.argmax(scores)
    return label_mapping[best_index]
def main(args):
    """Run inference for the options passed on the command line.

    args:
        - args: argparse.Namespace
    returns:
        - None
    """
    # Sanity check on args.
    check_args(args)
    # Load the Arm NN delegate and the TFLite model that uses it.
    armnn_delegate = load_delegate(args.delegate_path, args.preferred_backends)
    interpreter = load_tf_model(args.model_file, armnn_delegate)
    # Indices 1 and 2 are treated as height and width, i.e. NHWC layout is
    # assumed: (batch, height, width[, channels]).
    model_input_shape = interpreter.get_input_details()[0]["shape"]
    height, width = model_input_shape[1], model_input_shape[2]
    # A 4-D input with a single channel expects a greyscale image.
    grayscale = bool(len(model_input_shape) == 4 and model_input_shape[3] == 1)
    input_image = load_image(args.input_image, (height, width), grayscale)
    # Get the label mapping.
    labelmapping = create_mapping(args.label_file)
    output_tensor = run_inference(interpreter, input_image)
    output_prediction = process_output(output_tensor, labelmapping)
    print("Prediction: ", output_prediction)
    return None
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # The four mandatory path options, registered in their original order.
    required_path_options = (
        ("--input_image", "File path of image file"),
        ("--model_file", "File path of the model tflite file"),
        ("--label_file", "File path of model labelmapping file"),
        ("--delegate_path", "File path of ArmNN delegate file"),
    )
    for option_flag, option_help in required_path_options:
        parser.add_argument(option_flag, help=option_help, type=Path, required=True)
    # Optional list of backends; validated later by check_args.
    parser.add_argument(
        "--preferred_backends",
        help="list of backends in order of preference",
        type=str,
        nargs="+",
        required=False,
        default=["CpuAcc", "CpuRef"],
    )
    args = parser.parse_args()
    main(args)