# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
| 3 | |
| 4 | import os |
| 5 | from typing import List, Tuple |
| 6 | |
| 7 | import numpy as np |
| 8 | from tflite_runtime import interpreter as tflite |
| 9 | |
class TFLiteNetworkExecutor:
    """Loads a TFLite model through the Arm NN TFLite delegate and runs inference on it."""

    def __init__(self, model_file: str, backends: list, delegate_path: str):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network.
            delegate_path: tflite delegate file path (.so).
        """
        self.model_file = model_file
        self.backends = backends
        self.delegate_path = delegate_path
        self.interpreter, self.input_details, self.output_details = self.create_network()

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Executes inference for the loaded network.

        Args:
            input_data_list: List of input frames, one ndarray per model input tensor,
                in the same order as the model's input details.

        Returns:
            list: Inference results as a list of ndarrays, one per model output tensor.
        """
        # Bind every input tensor first, then invoke exactly once. Invoking
        # inside the binding loop would run inference with partially bound
        # inputs and collect the outputs once per input tensor.
        for index, input_data in enumerate(input_data_list):
            self.interpreter.set_tensor(self.input_details[index]['index'], input_data)
        self.interpreter.invoke()

        return [self.interpreter.get_tensor(detail['index']) for detail in self.output_details]

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            interpreter: A TensorFlow Lite object for executing inference.
            input_details: Contains essential information about the model input.
            output_details: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file is not a .tflite file.
        """

        # Controls whether optimizations are used or not.
        # Please note that optimizations can improve performance in some cases, but it can also
        # degrade the performance in other cases. Accuracy might also be affected.
        # The delegate options API takes string values, hence "true" rather than True.
        optimization_enable = "true"

        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        # Compare case-insensitively so '.TFLITE' and '.tflite' are both accepted.
        if ext.lower() == '.tflite':
            armnn_delegate = tflite.load_delegate(library=self.delegate_path,
                                                  options={"backends": ','.join(self.backends),
                                                           "logging-severity": "info",
                                                           "enable-fast-math": optimization_enable,
                                                           "reduce-fp32-to-fp16": optimization_enable})
            interpreter = tflite.Interpreter(model_path=self.model_file,
                                             experimental_delegates=[armnn_delegate])
            interpreter.allocate_tensors()
        else:
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")

        # Get input and output binding information
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        return interpreter, input_details, output_details

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            numpy dtype of the network's first input tensor.
        """
        return self.input_details[0]['dtype']

    def get_shape(self):
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The shape of the network's first input tensor.
        """
        return tuple(self.input_details[0]['shape'])