blob: 5496aedb21ce1aebb5db12e152c3d5b3d98d18c6 [file] [log] [blame]
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +01001#
2# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
3# SPDX-License-Identifier: Apache-2.0
4#
5import pytest
6import os
7import ethosu_driver as driver
8from ethosu_driver.inference_runner import read_npy_file_to_buf
9
10
@pytest.fixture()
def device(device_name):
    """Open the driver device node ``/dev/<device_name>`` for a test.

    ``device_name`` is supplied by the requesting test's ``parametrize`` marker.
    """
    dev_path = f"/dev/{device_name}"
    opened = driver.Device(dev_path)
    yield opened
15
16
@pytest.fixture()
def network_buffer(device, model_name, shared_data_folder):
    """Load ``model_name`` from ``shared_data_folder`` into a driver Buffer.

    The Buffer is created against the ``device`` fixture.
    """
    model_path = os.path.join(shared_data_folder, model_name)
    yield driver.Buffer(device, model_path)
22
23
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_device_swig_ownership(device):
    """The Python proxy must own the native Device object.

    With ``thisown`` set, SWIG owns the wrapped object, so it is released
    automatically once the proxy is garbage-collected.
    """
    owned = device.thisown
    assert owned
30
31
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_device_ping(device):
    """Pinging an opened device must complete without raising."""
    target = device
    target.ping()
35
36
@pytest.mark.parametrize('device_name', ['blabla'])
def test_device_wrong_name(device_name):
    """Opening a device node that does not exist must raise ``RuntimeError``."""
    bogus_path = f"/dev/{device_name}"
    with pytest.raises(RuntimeError) as excinfo:
        driver.Device(bogus_path)
    # Match only a stable substring of the message rather than its full text.
    assert 'Failed to open device' in str(excinfo.value)
44
45
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_driver_network_filenotfound_exception(device, shared_data_folder):
    """Creating a Buffer from a missing model file must raise ``RuntimeError``."""
    missing_model = os.path.join(shared_data_folder, "some_unknown_model.tflite")

    with pytest.raises(RuntimeError) as err:
        # The result is deliberately not bound: construction is expected to fail,
        # and binding it would shadow the ``network_buffer`` fixture name.
        driver.Buffer(device, missing_model)

    # Only check part of the message: it includes the absolute file path,
    # which differs between machines.
    assert 'Failed to open file:' in str(err.value)
57
58
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_swig_ownership(network_buffer):
    """The Python proxy must own the native Buffer object.

    With ``thisown`` set, SWIG owns the wrapped object, so it is released
    automatically once the proxy is garbage-collected.
    """
    owned = network_buffer.thisown
    assert owned
66
67
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_capacity(network_buffer):
    """A buffer loaded from a model file must report a positive capacity."""
    cap = network_buffer.capacity()
    assert cap > 0
72
73
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_size(network_buffer):
    """A buffer loaded from a model file must report a positive size."""
    used = network_buffer.size()
    assert used > 0
78
79
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_clear(network_buffer):
    """After ``clear()`` the buffer must report a size of zero."""
    network_buffer.clear()
    remaining = network_buffer.size()
    assert remaining == 0
85
86
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_resize(network_buffer):
    """``resize(size, offset)`` must leave the buffer reporting the new size.

    The requested size is the capacity minus the offset.
    """
    start = 1
    target_size = network_buffer.capacity() - start
    network_buffer.resize(target_size, start)
    assert network_buffer.size() == target_size
94
95
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_getFd(network_buffer):
    """A loaded buffer must expose a non-negative file descriptor."""
    fd = network_buffer.getFd()
    assert fd >= 0
100
101
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(device, network_buffer):
    """A Network built from a model buffer must report a positive IFM size."""
    net = driver.Network(device, network_buffer)
    ifm_size = net.getIfmSize()
    assert ifm_size > 0
    # The buffer proxy must still own its native object after being handed to Network.
    assert network_buffer.thisown
108
109
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_network_buffer_none(device):
    """Constructing a Network with no model buffer must raise ``RuntimeError``."""
    with pytest.raises(RuntimeError) as excinfo:
        driver.Network(device, None)

    # Match only a stable substring of the message rather than its full text.
    assert 'Failed to create the network' in str(excinfo.value)
119
120
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ofm_size(device, network_buffer):
    """A Network built from a model buffer must report a positive OFM size."""
    net = driver.Network(device, network_buffer)
    ofm_size = net.getOfmSize()
    assert ofm_size > 0
126
127
def test_getMaxPmuEventCounters():
    """The driver's maximum number of PMU event counters must be positive."""
    max_counters = driver.Inference.getMaxPmuEventCounters()
    assert max_counters > 0
130
131
@pytest.fixture()
def inf(device_name, model_name, input_files, timeout, shared_data_folder):
    """Run one inference of ``model_name`` and yield the resulting Inference.

    Parameters (supplied via ``parametrize`` or other fixtures):
        device_name: name passed to ``driver.open_device``.
        model_name: model file name, relative to ``shared_data_folder``.
        input_files: ``.npy`` input file names, relative to
            ``shared_data_folder``. They are read in order and copied into the
            network's IFM buffers.
        timeout: converted with ``int()`` and passed to ``Inference.wait``.
            NOTE(review): the tests pass 5000000000, which looks like
            nanoseconds; confirm against the driver API.
        shared_data_folder: directory holding the model and input files.

    The inference is created with a PMU event list of ``[1]`` and the cycle
    counter enabled, so tests can query ``getPmuCounters()`` and
    ``getCycleCounter()`` on the yielded object.
    """
    # Resolve the model and input files against the shared data folder.
    model_path = os.path.join(shared_data_folder, model_name)
    ifms_data = [read_npy_file_to_buf(os.path.join(shared_data_folder, input_file))
                 for input_file in input_files]

    device = driver.open_device(device_name)
    device.ping()
    network = driver.load_model(device, model_path)
    ofms = driver.allocate_buffers(device, network.getOfmDims())
    ifms = driver.allocate_buffers(device, network.getIfmDims())

    # Copy the input data into the IFM buffers before starting the inference.
    driver.populate_buffers(ifms_data, ifms)

    # NOTE(review): 1 is the single PMU event requested; its meaning depends on
    # the Ethos-U PMU event encoding. Confirm.
    ethos_pmu_counters = [1]
    enable_cycle_counter = True
    inf_inst = driver.Inference(network, ifms, ofms, ethos_pmu_counters,
                                enable_cycle_counter)
    # NOTE(review): the return value of wait() is not checked here.
    inf_inst.wait(int(timeout))

    yield inf_inst
158
159
@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
                         [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
def test_inf_get_cycle_counter(inf):
    """The cycle counter of the completed inference must be non-negative."""
    assert inf.getCycleCounter() >= 0
165
166
@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
                         [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
def test_inf_get_pmu_counters(inf):
    """The completed inference must report at least one PMU counter value."""
    counters = inf.getPmuCounters()
    assert len(counters) > 0
172
173
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_capabilities(device):
    """``hwId``, ``hwCfg`` and ``driver`` on the capabilities object must be truthy."""
    capabilities = device.capabilities()
    for attr in ('hwId', 'hwCfg', 'driver'):
        assert getattr(capabilities, attr)