# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
from typing import List, Tuple

import pyarmnn as ann
import numpy as np

def create_network(model_file: str, backends: list, input_names: Tuple[str, ...] = (), output_names: Tuple[str, ...] = ()):
12 """
13 Creates a network based on the model file and a list of backends.
14
15 Args:
16 model_file: User-specified model file.
17 backends: List of backends to optimize network.
18 input_names:
19 output_names:
20
21 Returns:
22 net_id: Unique ID of the network to run.
23 runtime: Runtime context for executing inference.
24 input_binding_info: Contains essential information about the model input.
25 output_binding_info: Used to map output tensor and its memory.
26 """
    if not os.path.exists(model_file):
        raise FileNotFoundError(f'Model file not found: {model_file}')

    _, ext = os.path.splitext(model_file)
    if ext == '.tflite':
        parser = ann.ITfLiteParser()
    else:
        raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

    network = parser.CreateNetworkFromBinaryFile(model_file)

    # Specify backends to optimize network
    preferred_backends = [ann.BackendId(b) for b in backends]

    # Select appropriate device context and optimize the network for that device
    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)
    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n'
          f'Optimization warnings: {messages}')

    # Load the optimized network onto the Runtime device
    net_id, _ = runtime.LoadNetwork(opt_network)

    # Get input and output binding information
    graph_id = parser.GetSubgraphCount() - 1
    input_names = parser.GetSubgraphInputTensorNames(graph_id)
    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
    output_names = parser.GetSubgraphOutputTensorNames(graph_id)
    output_binding_info = []
    for output_name in output_names:
        out_bind_info = parser.GetNetworkOutputBindingInfo(graph_id, output_name)
        output_binding_info.append(out_bind_info)
    return net_id, runtime, input_binding_info, output_binding_info
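
# Note: per the PyArmNN parser bindings, each binding info returned above is a
# (layer binding id, ann.TensorInfo) pair, so callers can inspect the model's
# tensor metadata. A minimal sketch (names are illustrative):
#
#   binding_id, tensor_info = input_binding_info
#   print(tensor_info.GetShape(), tensor_info.GetDataType())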


def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> List[np.ndarray]:
    """
    Executes inference for the loaded network.

    Args:
        input_tensors: The input tensors to feed into the network, e.g. as
            created by ann.make_input_tensors().
        output_tensors: The output tensors that inference results are written
            into, e.g. as created by ann.make_output_tensors().
        runtime: Runtime context for executing inference.
        net_id: Unique ID of the network to run.

    Returns:
        list: Inference results as a list of ndarrays.
    """
    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
    output = ann.workload_tensors_to_ndarray(output_tensors)
    return output
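
# Note: runtime.EnqueueWorkload() writes results into output_tensors in place,
# so the same output tensors can be reused across calls; ArmnnNetworkExecutor
# below relies on this by creating them once in __init__. A minimal sketch of
# the function-level flow, assuming net_id/runtime/binding info returned by
# create_network() and a preprocessed ndarray input_data (illustrative names):
#
#   input_tensors = ann.make_input_tensors([input_binding_info], [input_data])
#   output_tensors = ann.make_output_tensors(output_binding_info)
#   results = execute_network(input_tensors, output_tensors, runtime, net_id)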


class ArmnnNetworkExecutor:

    def __init__(self, model_file: str, backends: list):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backend names, in preference order, with which to
                optimize the network.
        """
        self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = create_network(model_file,
                                                                                                          backends)
        self.output_tensors = ann.make_output_tensors(self.output_binding_info)

    def run(self, input_tensors: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        Args:
            input_tensors: The input tensors to feed into the network, e.g. as
                created by ann.make_input_tensors().

        Returns:
            list: Inference results as a list of ndarrays.
        """
        return execute_network(input_tensors, self.output_tensors, self.runtime, self.network_id)
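

# A minimal usage sketch, kept under a __main__ guard so importing this module
# stays side-effect free. The model path, backend list, and the uint8 NHWC
# input shape are illustrative assumptions; substitute values matching your
# model.
if __name__ == '__main__':
    executor = ArmnnNetworkExecutor('model.tflite', ['CpuAcc', 'CpuRef'])
    # Dummy input standing in for a preprocessed frame (assumes a quantized
    # uint8 model with a 1x224x224x3 input).
    dummy_input = np.random.randint(0, 256, size=(1, 224, 224, 3), dtype=np.uint8)
    input_tensors = ann.make_input_tensors([executor.input_binding_info], [dummy_input])
    results = executor.run(input_tensors)
    print([result.shape for result in results])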