# Copyright © 2019 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import os

import pytest
import pyarmnn as ann
import numpy as np

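# Both fixtures below parse the ssd_mobilenetv1.tflite sample model from shared_data_folder and
# yield the binding info consumed by the tensor conversion helpers under test
# (make_input_tensors, make_output_tensors, workload_tensors_to_ndarray).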
@pytest.fixture(scope="function")
def get_tensor_info_input(shared_data_folder):
    """
    Sample input tensor information.
    """
    parser = ann.ITfLiteParser()
    parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, 'ssd_mobilenetv1.tflite'))
    graph_id = 0

    input_binding_info = [parser.GetNetworkInputBindingInfo(graph_id, 'normalized_input_image_tensor')]

    yield input_binding_info

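# Yields binding info for every subgraph output; the SSD MobileNet V1 detection model used here
# exposes four output tensors (see test_make_output_tensors below).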
@pytest.fixture(scope="function")
def get_tensor_info_output(shared_data_folder):
    """
    Sample output tensor information.
    """
    parser = ann.ITfLiteParser()
    parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, 'ssd_mobilenetv1.tflite'))
    graph_id = 0

    output_names = parser.GetSubgraphOutputTensorNames(graph_id)
    outputs_binding_info = []

    for output_name in output_names:
        outputs_binding_info.append(parser.GetNetworkOutputBindingInfo(graph_id, output_name))

    yield outputs_binding_info

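# make_input_tensors pairs each input binding with the matching numpy array; every returned entry
# holds a ConstTensor whose TensorInfo matches the binding it was built from.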
def test_make_input_tensors(get_tensor_info_input):
    input_tensor_info = get_tensor_info_input
    input_data = []

    for tensor_id, tensor_info in input_tensor_info:
        input_data.append(np.random.randint(0, 255, size=(1, tensor_info.GetNumElements())).astype(np.uint8))

    input_tensors = ann.make_input_tensors(input_tensor_info, input_data)
    assert len(input_tensors) == 1

    for tensor, tensor_info in zip(input_tensors, input_tensor_info):
        # ConstTensor objects are created via a function, so the type cannot be checked directly;
        # compare the class name instead.
        assert type(tensor[1]).__name__ == 'ConstTensor'
        assert str(tensor[1].GetInfo()) == str(tensor_info[1])

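# make_output_tensors allocates one output Tensor per binding, so four tensors are expected for
# this model.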
def test_make_output_tensors(get_tensor_info_output):
    output_binding_info = get_tensor_info_output

    output_tensors = ann.make_output_tensors(output_binding_info)
    assert len(output_tensors) == 4

    for tensor, tensor_info in zip(output_tensors, output_binding_info):
        assert type(tensor[1]) == ann.Tensor
        assert str(tensor[1].GetInfo()) == str(tensor_info[1])

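# workload_tensors_to_ndarray converts each output tensor to an ndarray whose length matches the
# element count reported by the bound TensorInfo.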
def test_workload_tensors_to_ndarray(get_tensor_info_output):
    output_binding_info = get_tensor_info_output
    output_tensors = ann.make_output_tensors(output_binding_info)

    data = ann.workload_tensors_to_ndarray(output_tensors)

    for i in range(len(output_tensors)):
        assert len(data[i]) == output_tensors[i][1].GetNumElements()

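# Float16 path: 270000 input elements at two bytes each give the 540000 bytes asserted below.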
def test_make_input_tensors_fp16(get_tensor_info_input):
    # Check that ConstTensor creation works with float16 data.
    input_tensor_info = get_tensor_info_input
    input_data = []

    for tensor_id, tensor_info in input_tensor_info:
        input_data.append(np.random.randint(0, 255, size=(1, tensor_info.GetNumElements())).astype(np.float16))
        tensor_info.SetDataType(ann.DataType_Float16)  # set the data type to float16

    input_tensors = ann.make_input_tensors(input_tensor_info, input_data)
    assert len(input_tensors) == 1

    for tensor, tensor_info in zip(input_tensors, input_tensor_info):
        # ConstTensor objects are created via a function, so the type cannot be checked directly;
        # compare the class name instead.
        assert type(tensor[1]).__name__ == 'ConstTensor'
        assert str(tensor[1].GetInfo()) == str(tensor_info[1])
        assert tensor[1].GetDataType() == ann.DataType_Float16
        assert tensor[1].GetNumElements() == 270000
        assert tensor[1].GetNumBytes() == 540000  # check that each element is two bytes
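

# Illustrative sketch (not exercised by this test suite): how the helpers above are typically used
# around IRuntime.EnqueueWorkload. The backend name, option objects and parameter names are
# assumptions and may need adjusting for a particular Arm NN build.
def _example_inference_flow(runtime, network, input_binding_info, outputs_binding_info, input_data):
    # Optimize the parsed network for the chosen backend and load it into the runtime.
    opt_network, _ = ann.Optimize(network, [ann.BackendId('CpuRef')],
                                  runtime.GetDeviceSpec(), ann.OptimizerOptions())
    net_id, _ = runtime.LoadNetwork(opt_network)

    # Wrap the numpy input and the output bindings with the helpers tested above.
    input_tensors = ann.make_input_tensors(input_binding_info, [input_data])
    output_tensors = ann.make_output_tensors(outputs_binding_info)

    # Run inference and convert the populated output tensors back to ndarrays.
    runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
    return ann.workload_tensors_to_ndarray(output_tensors)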