Éanna Ó Catháin | 145c88f | 2020-11-16 14:12:11 +0000 | [diff] [blame] | 1 | # Copyright © 2020 Arm Ltd and Contributors. All rights reserved. |
| 2 | # SPDX-License-Identifier: MIT |
| 3 | |
| 4 | import os |
| 5 | from typing import List, Tuple |
| 6 | |
| 7 | import pyarmnn as ann |
| 8 | import numpy as np |
| 9 | |
class ArmnnNetworkExecutor:
    """
    Runs inference on a TfLite model with the PyArmNN runtime.

    The model is parsed, optimized for the preferred backends and loaded once,
    at construction time. `run` can then be called repeatedly; the output
    tensors are allocated once and reused for every call.
    """

    def __init__(self, model_file: str, backends: list):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file. Only `.tflite` files are supported.
            backends: Names of the preferred backends used to optimize the network.

        Raises:
            FileNotFoundError: If `model_file` is not an existing regular file.
            ValueError: If the model file type is not supported.
        """
        self.model_file = model_file
        self.backends = backends
        self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network()
        # Allocated once here and reused (overwritten) by every call to `run`.
        self.output_tensors = ann.make_output_tensors(self.output_binding_info)

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Creates input tensors from input data and executes inference with the loaded network.

        Args:
            input_data_list: List of input frames. Presumably one entry per network
                input, in the order of `self.input_binding_info` — confirm against callers.

        Returns:
            list: Inference results as a list of ndarrays, one per network output.
        """
        input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list)
        self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors)
        return ann.workload_tensors_to_ndarray(self.output_tensors)

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            net_id: Unique ID of the network to run.
            runtime: Runtime context for executing inference.
            input_binding_info: Contains essential information about the model input.
            output_binding_info: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If `self.model_file` is not an existing regular file.
            ValueError: If the model file extension is not `.tflite` (case-insensitive).
        """
        # isfile (not exists) so that a directory is reported as a missing model file
        # instead of failing later with a misleading error.
        if not os.path.isfile(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        # Case-insensitive so that e.g. 'model.TFLITE' is accepted as well.
        if ext.lower() == '.tflite':
            parser = ann.ITfLiteParser()
        else:
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        network = parser.CreateNetworkFromBinaryFile(self.model_file)

        # Specify backends to optimize network
        preferred_backends = [ann.BackendId(backend) for backend in self.backends]

        # Select appropriate device context and optimize the network for that device
        options = ann.CreationOptions()
        runtime = ann.IRuntime(options)
        opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                             ann.OptimizerOptions())
        print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n'
              f'Optimization warnings: {messages}')

        # Load the optimized network onto the Runtime device
        net_id, _ = runtime.LoadNetwork(opt_network)

        # Get input and output binding information from the last subgraph reported by the parser.
        graph_id = parser.GetSubgraphCount() - 1
        input_binding_info = [parser.GetNetworkInputBindingInfo(graph_id, name)
                              for name in parser.GetSubgraphInputTensorNames(graph_id)]
        output_binding_info = [parser.GetNetworkOutputBindingInfo(graph_id, name)
                               for name in parser.GetSubgraphOutputTensorNames(graph_id)]
        return net_id, runtime, input_binding_info, output_binding_info

    def get_data_type(self):
        """
        Get the NumPy data type matching the first input of the initiated network.

        Returns:
            np.float32, np.uint8 or np.int8 for Float32, QAsymmU8 and QAsymmS8 inputs
            respectively; None for any other ArmNN data type.
        """
        data_type = self.input_binding_info[0][1].GetDataType()
        if data_type == ann.DataType_Float32:
            return np.float32
        if data_type == ann.DataType_QAsymmU8:
            return np.uint8
        if data_type == ann.DataType_QAsymmS8:
            return np.int8
        return None

    def get_shape(self) -> tuple:
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The shape of the first network input.
        """
        return tuple(self.input_binding_info[0][1].GetShape())

    def get_input_quantization_scale(self, idx: int) -> float:
        """
        Get the input quantization scale of the initiated network.

        Args:
            idx: Index of the network input.

        Returns:
            The quantization scale of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationScale()

    def get_input_quantization_offset(self, idx: int) -> int:
        """
        Get the input quantization offset of the initiated network.

        Args:
            idx: Index of the network input.

        Returns:
            The quantization offset of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationOffset()

    def is_output_quantized(self, idx: int) -> bool:
        """
        Get True/False if output tensor is quantized or not respectively.

        Args:
            idx: Index of the network output.

        Returns:
            True if output is quantized and False otherwise.
        """
        return self.output_binding_info[idx][1].IsQuantized()

    def get_output_quantization_scale(self, idx: int) -> float:
        """
        Get the output quantization scale of the initiated network.

        Args:
            idx: Index of the network output.

        Returns:
            The quantization scale of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationScale()

    def get_output_quantization_offset(self, idx: int) -> int:
        """
        Get the output quantization offset of the initiated network.

        Args:
            idx: Index of the network output.

        Returns:
            The quantization offset of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationOffset()
| 161 | |