Jan Eilers | 2cd1847 | 2020-12-15 10:42:38 +0000 | [diff] [blame] | 1 | # Copyright © 2020 Arm Ltd and Contributors. All rights reserved. |
| 2 | # SPDX-License-Identifier: MIT |
| 3 | |
| 4 | import tflite_runtime.interpreter as tflite |
| 5 | import numpy as np |
| 6 | import os |
| 7 | |
| 8 | |
def run_mock_model(delegate, test_data_folder):
    """Load 'mock_model.tflite' with ``delegate`` and invoke it once on random input.

    This is a smoke test: the model is run but its outputs are not inspected.

    Args:
        delegate: A TfLite delegate object passed to the interpreter via
            ``experimental_delegates``.
        test_data_folder: Directory containing 'mock_model.tflite'.
    """
    model_path = os.path.join(test_data_folder, 'mock_model.tflite')
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=[delegate])
    interpreter.allocate_tensors()

    # Only the first input tensor is populated; presumably the mock model has
    # a single input — TODO confirm if the model ever gains more inputs.
    input_detail = interpreter.get_input_details()[0]
    input_shape = input_detail['shape']
    input_dtype = input_detail['dtype']

    # Generate random data in the tensor's own dtype. Casting
    # random_sample() (floats in [0, 1)) straight to an integer type would
    # truncate every value to 0, so integer tensors draw from their full range.
    if np.issubdtype(input_dtype, np.integer):
        limits = np.iinfo(input_dtype)
        input_data = np.random.randint(limits.min, limits.max + 1,
                                       size=input_shape, dtype=input_dtype)
    else:
        input_data = np.random.random_sample(input_shape).astype(input_dtype)

    interpreter.set_tensor(input_detail['index'], input_data)
    interpreter.invoke()
| 25 | |
def run_inference(test_data_folder, model_filename, inputs, delegates=None):
    """Run a TfLite model once and return all of its output tensors.

    Args:
        test_data_folder: Directory containing the model file.
        model_filename: Name of the ``.tflite`` file inside ``test_data_folder``.
        inputs: Sequence of input arrays, one per model input, in the order
            reported by ``get_input_details()``.
        delegates: Optional list of delegates passed to the interpreter as
            ``experimental_delegates``.

    Returns:
        A list of output arrays, in the order reported by
        ``get_output_details()``.

    Raises:
        ValueError: If the number of ``inputs`` does not match the number of
            model inputs.
    """
    model_path = os.path.join(test_data_folder, model_filename)
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=delegates)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Fail loudly on a count mismatch: too few inputs would otherwise leave
    # some model inputs silently unset, too many would surface as an
    # IndexError far from the real cause.
    if len(inputs) != len(input_details):
        raise ValueError(
            'Model {} expects {} input(s) but {} were given'.format(
                model_filename, len(input_details), len(inputs)))

    for detail, data in zip(input_details, inputs):
        interpreter.set_tensor(detail['index'], data)

    interpreter.invoke()

    return [interpreter.get_tensor(detail['index']) for detail in output_details]
| 47 | |
def compare_outputs(outputs, expected_outputs, rtol=0.0, atol=0.0):
    """Assert that ``outputs`` match ``expected_outputs`` in count, shape, dtype and values.

    Args:
        outputs: Sequence of numpy arrays produced by the model under test.
        expected_outputs: Sequence of reference numpy arrays.
        rtol: Relative tolerance for the value comparison (default 0: exact).
        atol: Absolute tolerance for the value comparison (default 0: exact).

    Raises:
        AssertionError: On the first mismatch found, naming the output index.
    """
    assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'
    for i, (output, expected) in enumerate(zip(outputs, expected_outputs)):
        assert output.shape == expected.shape, 'Incorrect output shape on output#{}'.format(i)
        assert output.dtype == expected.dtype, 'Incorrect output data type on output#{}'.format(i)
        # Compare element-wise. Comparing `.all()` of each array only checks
        # whether both are entirely non-zero, which ignores the actual values.
        if rtol == 0 and atol == 0:
            values_match = np.array_equal(output, expected)
        else:
            values_match = np.allclose(output, expected, rtol=rtol, atol=atol)
        assert values_match, 'Incorrect output value on output#{}'.format(i)