# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
from typing import List, Tuple

import pyarmnn as ann
import numpy as np
class ArmnnNetworkExecutor:

    def __init__(self, model_file: str, backends: list):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network.

        Raises:
            FileNotFoundError: If `model_file` does not exist.
            ValueError: If the model file type is not supported (only .tflite is).
        """
        self.model_file = model_file
        self.backends = backends
        self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network()
        # Output tensors are created once here and reused for every run() call.
        self.output_tensors = ann.make_output_tensors(self.output_binding_info)

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Creates input tensors from input data and executes inference with the loaded network.

        Args:
            input_data_list: List of input frames, one entry per network input.

        Returns:
            list: Inference results as a list of ndarrays.
        """
        input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list)
        self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors)
        return ann.workload_tensors_to_ndarray(self.output_tensors)

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            net_id: Unique ID of the network to run.
            runtime: Runtime context for executing inference.
            input_binding_info: Contains essential information about the model input.
            output_binding_info: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If `self.model_file` does not exist.
            ValueError: If the model file extension is not '.tflite'.
        """
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        if ext == '.tflite':
            parser = ann.ITfLiteParser()
        else:
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        network = parser.CreateNetworkFromBinaryFile(self.model_file)

        # Specify backends to optimize network
        preferred_backends = [ann.BackendId(b) for b in self.backends]

        # Select appropriate device context and optimize the network for that device
        options = ann.CreationOptions()
        runtime = ann.IRuntime(options)
        opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                             ann.OptimizerOptions())
        print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n'
              f'Optimization warnings: {messages}')

        # Load the optimized network onto the Runtime device
        net_id, _ = runtime.LoadNetwork(opt_network)

        # Get input and output binding information for the last subgraph
        graph_id = parser.GetSubgraphCount() - 1
        input_names = parser.GetSubgraphInputTensorNames(graph_id)
        input_binding_info = []
        for input_name in input_names:
            in_bind_info = parser.GetNetworkInputBindingInfo(graph_id, input_name)
            input_binding_info.append(in_bind_info)
        output_names = parser.GetSubgraphOutputTensorNames(graph_id)
        output_binding_info = []
        for output_name in output_names:
            out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
            output_binding_info.append(out_bind_info)
        return net_id, runtime, input_binding_info, output_binding_info

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            numpy data type, or None if the Arm NN data type is not one of
            Float32 / QAsymmU8 / QAsymmS8.
        """
        # Query once and dispatch through a table instead of repeated GetDataType() calls.
        dtype_map = {
            ann.DataType_Float32: np.float32,
            ann.DataType_QAsymmU8: np.uint8,
            ann.DataType_QAsymmS8: np.int8,
        }
        return dtype_map.get(self.input_binding_info[0][1].GetDataType())

    def get_shape(self):
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The Shape of the network input.
        """
        return tuple(self.input_binding_info[0][1].GetShape())

    def get_input_quantization_scale(self, idx):
        """
        Get the input quantization scale of the initiated network.

        Args:
            idx: Index of the network input.

        Returns:
            The quantization scale of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationScale()

    def get_input_quantization_offset(self, idx):
        """
        Get the input quantization offset of the initiated network.

        Args:
            idx: Index of the network input.

        Returns:
            The quantization offset of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationOffset()

    def is_output_quantized(self, idx):
        """
        Get True/False if output tensor is quantized or not respectively.

        Args:
            idx: Index of the network output.

        Returns:
            True if output is quantized and False otherwise.
        """
        return self.output_binding_info[idx][1].IsQuantized()

    def get_output_quantization_scale(self, idx):
        """
        Get the output quantization scale of the initiated network.

        Args:
            idx: Index of the network output.

        Returns:
            The quantization scale of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationScale()

    def get_output_quantization_offset(self, idx):
        """
        Get the output quantization offset of the initiated network.

        Args:
            idx: Index of the network output.

        Returns:
            The quantization offset of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationOffset()
161