blob: f3761ec8a1149f8727f45f0c58257de535375e5b [file] [log] [blame]
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
3
4import tflite_runtime.interpreter as tflite
5import numpy as np
6import os
7
8
def run_mock_model(delegate, test_data_folder):
    """Run 'mock_model.tflite' once through the given delegate on random input.

    Args:
        delegate: Delegate object handed to the interpreter through
            ``experimental_delegates``.
        test_data_folder: Directory that contains 'mock_model.tflite'.

    Returns:
        None. The model is invoked only for its side effects on the delegate;
        no output tensor is read back.
    """
    model_path = os.path.join(test_data_folder, 'mock_model.tflite')
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=[delegate])
    interpreter.allocate_tensors()

    # Get input tensor details.
    input_details = interpreter.get_input_details()

    # Test model on random input data.
    # random_sample() yields floats in [0, 1), which all truncate to 0 when
    # cast to uint8, so draw integers across the full uint8 range instead.
    input_shape = tuple(int(dim) for dim in input_details[0]['shape'])
    input_data = np.random.randint(0, 256, size=input_shape, dtype=np.uint8)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
25
def run_inference(test_data_folder, model_filename, inputs, delegates=None):
    """Load a model, feed it the given inputs, invoke it and collect outputs.

    Args:
        test_data_folder: Directory that contains the model file.
        model_filename: File name of the model inside ``test_data_folder``.
        inputs: Sequence of input values; the value at position ``i`` is set
            on the tensor described by the interpreter's ``i``-th input detail.
        delegates: Optional value forwarded to the interpreter as
            ``experimental_delegates``.

    Returns:
        A list holding the tensor fetched for every output detail, in the
        order the interpreter reports them.
    """
    interpreter = tflite.Interpreter(
        model_path=os.path.join(test_data_folder, model_filename),
        experimental_delegates=delegates)
    interpreter.allocate_tensors()

    # Look up the tensor descriptors once, before running the model.
    input_info = interpreter.get_input_details()
    output_info = interpreter.get_output_details()

    # Bind each supplied input to the tensor at the same position.
    for position, tensor in enumerate(inputs):
        interpreter.set_tensor(input_info[position]['index'], tensor)

    interpreter.invoke()

    return [interpreter.get_tensor(entry['index']) for entry in output_info]
47
def compare_outputs(outputs, expected_outputs):
    """Assert that every produced output matches its expected counterpart.

    Args:
        outputs: Sequence of arrays produced by the model.
        expected_outputs: Sequence of reference arrays, in the same order.

    Raises:
        AssertionError: If the number of outputs differs, or if any pair
            differs in shape, dtype or element values. Values are compared
            exactly, element by element.
    """
    assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'
    for i, (actual, expected) in enumerate(zip(outputs, expected_outputs)):
        assert actual.shape == expected.shape, 'Incorrect output shape on output#{}'.format(i)
        assert actual.dtype == expected.dtype, 'Incorrect output data type on output#{}'.format(i)
        # ndarray.all() only reports whether every element is truthy, so
        # comparing the two .all() results never looked at the actual values.
        assert np.array_equal(actual, expected), 'Incorrect output value on output#{}'.format(i)