blob: 10f5e6e6fbe6c4e69f024a27be53537106eb813e [file] [log] [blame]
# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
import os
from typing import List, Tuple
import numpy as np
from tflite_runtime import interpreter as tflite
class TFLiteNetworkExecutor:
    """
    Runs inference on a TensorFlow Lite model through a delegate library.

    The interpreter is created once, at construction time, and is reused by
    every subsequent call to ``run``.
    """

    def __init__(self, model_file: str, backends: List[str], delegate_path: str):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backend names; they are joined with commas and
                handed to the delegate as its "backends" option.
            delegate_path: tflite delegate file path (.so).

        Raises:
            FileNotFoundError: If model_file does not exist.
            ValueError: If model_file does not have a tflite extension.
        """
        self.model_file = model_file
        self.backends = backends
        self.delegate_path = delegate_path
        self.interpreter, self.input_details, self.output_details = self.create_network()

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        The element at position i of input_data_list is written to the tensor
        described by self.input_details[i], so the list must follow the order
        of the model inputs.

        Args:
            input_data_list: List of input frames, one per model input.

        Returns:
            list: Inference results as a list of ndarrays, in the order of
            self.output_details.
        """
        for position, input_data in enumerate(input_data_list):
            tensor_index = self.input_details[position]['index']
            self.interpreter.set_tensor(tensor_index, input_data)

        self.interpreter.invoke()

        return [self.interpreter.get_tensor(details['index'])
                for details in self.output_details]

    def create_network(self) -> Tuple["tflite.Interpreter", List[dict], List[dict]]:
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            interpreter: A TensorFlow Lite object for executing inference.
            input_details: Contains essential information about the model input.
            output_details: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If self.model_file does not exist.
            ValueError: If self.model_file does not have a tflite extension.
        """
        # Controls whether optimizations are used or not.
        # Please note that optimizations can improve performance in some cases, but it can also
        # degrade the performance in other cases. Accuracy might also be affected.
        # The delegate options are passed as strings, hence "true" rather than True.
        optimization_enable = "true"

        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        # Compare the extension case-insensitively so that e.g. "MODEL.TFLITE" is accepted.
        _, ext = os.path.splitext(self.model_file)
        if ext.lower() != '.tflite':
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        delegate_options = {
            "backends": ','.join(self.backends),
            "logging-severity": "info",
            "enable-fast-math": optimization_enable,
            "reduce-fp32-to-fp16": optimization_enable,
        }
        armnn_delegate = tflite.load_delegate(library=self.delegate_path,
                                              options=delegate_options)
        interpreter = tflite.Interpreter(model_path=self.model_file,
                                         experimental_delegates=[armnn_delegate])
        interpreter.allocate_tensors()

        # Get input and output binding information
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        return interpreter, input_details, output_details

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Only the first model input is inspected.

        Returns:
            The 'dtype' entry of the first input's details.
        """
        return self.input_details[0]['dtype']

    def get_shape(self) -> tuple:
        """
        Get the input shape of the initiated network.

        Only the first model input is inspected.

        Returns:
            tuple: The Shape of the network input.
        """
        return tuple(self.input_details[0]['shape'])