# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
| 2 | # SPDX-License-Identifier: MIT |
| 3 | |
| 4 | import os |
| 5 | from typing import List, Tuple |
| 6 | |
| 7 | import pyarmnn as ann |
| 8 | import numpy as np |
| 9 | |
| 10 | |
def create_network(model_file: str, backends: list, input_names: Tuple[str, ...] = (),
                   output_names: Tuple[str, ...] = ()):
    """
    Creates a network based on the model file and a list of backends.

    Args:
        model_file: User-specified model file.
        backends: List of backends to optimize network.
        input_names: Optional input tensor names to bind; when empty, the
            names are queried from the parsed model.
        output_names: Optional output tensor names to bind; when empty, the
            names are queried from the parsed model.

    Returns:
        net_id: Unique ID of the network to run.
        runtime: Runtime context for executing inference.
        input_binding_info: Contains essential information about the model input.
        output_binding_info: Used to map output tensor and its memory.

    Raises:
        FileNotFoundError: If the model file does not exist.
        ValueError: If the model file type is not supported.
    """
    if not os.path.exists(model_file):
        raise FileNotFoundError(f'Model file not found for: {model_file}')

    # Dispatch on file extension; only TFLite models are supported here.
    _, ext = os.path.splitext(model_file)
    if ext == '.tflite':
        parser = ann.ITfLiteParser()
    else:
        raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

    network = parser.CreateNetworkFromBinaryFile(model_file)

    # Specify backends to optimize network
    preferred_backends = [ann.BackendId(b) for b in backends]

    # Select appropriate device context and optimize the network for that device
    options = ann.CreationOptions()
    runtime = ann.IRuntime(options)
    opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                         ann.OptimizerOptions())
    print(f'Preferred backends: {backends}\n{runtime.GetDeviceSpec()}\n'
          f'Optimization warnings: {messages}')

    # Load the optimized network onto the Runtime device
    net_id, _ = runtime.LoadNetwork(opt_network)

    # Get input and output binding information. Honour caller-supplied tensor
    # names; otherwise query the parser for the last subgraph's tensor names.
    # (The original implementation silently discarded the supplied names.)
    graph_id = parser.GetSubgraphCount() - 1
    if not input_names:
        input_names = parser.GetSubgraphInputTensorNames(graph_id)
    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
    if not output_names:
        output_names = parser.GetSubgraphOutputTensorNames(graph_id)
    output_binding_info = [parser.GetNetworkOutputBindingInfo(graph_id, output_name)
                           for output_name in output_names]
    return net_id, runtime, input_binding_info, output_binding_info
| 64 | |
| 65 | |
def execute_network(input_tensors: list, output_tensors: list, runtime, net_id: int) -> List[np.ndarray]:
    """
    Runs inference for the network loaded under the given ID.

    Args:
        input_tensors: The input frame tensor.
        output_tensors: The output tensor from output node.
        runtime: Runtime context for executing inference.
        net_id: Unique ID of the network to run.

    Returns:
        list: Inference results as a list of ndarrays.
    """
    # Enqueue the workload, then convert the populated output tensors.
    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
    return ann.workload_tensors_to_ndarray(output_tensors)
| 82 | |
| 83 | |
class ArmnnNetworkExecutor:

    def __init__(self, model_file: str, backends: list):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network.
        """
        # Build the network once and keep everything needed for inference.
        net_id, runtime, in_binding, out_binding = create_network(model_file, backends)
        self.network_id = net_id
        self.runtime = runtime
        self.input_binding_info = in_binding
        self.output_binding_info = out_binding
        # Pre-allocate output tensors; reused across every run() call.
        self.output_tensors = ann.make_output_tensors(self.output_binding_info)

    def run(self, input_tensors: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        Args:
            input_tensors: The input frame tensor.

        Returns:
            list: Inference results as a list of ndarrays.
        """
        return execute_network(input_tensors, self.output_tensors, self.runtime, self.network_id)