blob: 10f5e6e6fbe6c4e69f024a27be53537106eb813e [file] [log] [blame]
# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4import os
5from typing import List, Tuple
6
7import numpy as np
8from tflite_runtime import interpreter as tflite
9
class TFLiteNetworkExecutor:
    """Loads a TFLite model through the Arm NN TFLite delegate and runs inference on it."""

    def __init__(self, model_file: str, backends: list, delegate_path: str):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network.
            delegate_path: tflite delegate file path (.so).
        """
        self.model_file = model_file
        self.backends = backends
        self.delegate_path = delegate_path
        self.interpreter, self.input_details, self.output_details = self.create_network()

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        All input tensors are bound first (one frame per model input, in order),
        then the interpreter is invoked once, and every model output is collected.

        Args:
            input_data_list: List of input frames, one per model input tensor.

        Returns:
            list: Inference results as a list of ndarrays, one per model output tensor.
        """
        for index, input_data in enumerate(input_data_list):
            self.interpreter.set_tensor(self.input_details[index]['index'], input_data)
        self.interpreter.invoke()
        # Single invocation above; now read back every output tensor.
        return [self.interpreter.get_tensor(detail['index']) for detail in self.output_details]

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            interpreter: A TensorFlow Lite object for executing inference.
            input_details: Contains essential information about the model input.
            output_details: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file is not a .tflite file.
        """

        # Controls whether optimizations are used or not.
        # Please note that optimizations can improve performance in some cases, but it can also
        # degrade the performance in other cases. Accuracy might also be affected.
        # NOTE: delegate options are passed as strings, so this is deliberately
        # the string "true" rather than the boolean True.
        optimization_enable = "true"

        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        # Compare case-insensitively so e.g. "model.TFLITE" is also accepted.
        if ext.lower() == '.tflite':
            armnn_delegate = tflite.load_delegate(library=self.delegate_path,
                                                  options={"backends": ','.join(self.backends),
                                                           "logging-severity": "info",
                                                           "enable-fast-math": optimization_enable,
                                                           "reduce-fp32-to-fp16": optimization_enable})
            interpreter = tflite.Interpreter(model_path=self.model_file,
                                             experimental_delegates=[armnn_delegate])
            interpreter.allocate_tensors()
        else:
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        # Get input and output binding information
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        return interpreter, input_details, output_details

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            numpy data type of the first model input tensor.
        """
        return self.input_details[0]['dtype']

    def get_shape(self):
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The shape of the first model input tensor.
        """
        return tuple(self.input_details[0]['shape'])