blob: 0dd207fa7d34a1bebae9834ea978987a111941b9 [file] [log] [blame]
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +01001#
Mikael Olsson308e7f12023-06-12 15:00:55 +02002# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +01003# SPDX-License-Identifier: Apache-2.0
4#
5import pytest
6import os
7import ethosu_driver as driver
8from ethosu_driver.inference_runner import read_npy_file_to_buf
9
10
@pytest.fixture()
def device(device_name):
    """Open the device node ``/dev/<device_name>`` and hand it to the test.

    ``device_name`` comes from the requesting test's ``parametrize`` marker.
    """
    node_path = "/dev/{}".format(device_name)
    dev = driver.Device(node_path)
    yield dev
15
16
@pytest.fixture()
def network_buffer(device, model_name, shared_data_folder):
    """Create a driver Buffer on ``device`` from the model file.

    The model is looked up as ``model_name`` inside ``shared_data_folder``
    (presumably a fixture provided by conftest — not visible here).
    """
    model_path = os.path.join(shared_data_folder, model_name)
    buf = driver.Buffer(device, model_path)
    yield buf
22
23
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_device_swig_ownership(device):
    """Check that SWIG owns the Device proxy object.

    ``thisown`` being true means the Python wrapper owns the underlying C++
    Device, so it is released automatically when the wrapper is
    garbage-collected.
    """
    assert device.thisown
30
31
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_device_ping(device):
    """Smoke test: pinging the opened device must not raise."""
    device.ping()
35
36
@pytest.mark.parametrize('device_name', ['blabla'])
def test_device_wrong_name(device_name):
    """Opening a non-existent device node must raise RuntimeError."""
    bogus_node = "/dev/{}".format(device_name)
    with pytest.raises(RuntimeError) as exc_info:
        driver.Device(bogus_node)
    # The full message embeds a machine-specific absolute path, so only its
    # stable prefix is matched.
    assert 'Failed to open device' in str(exc_info.value)
44
45
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_driver_network_filenotfound_exception(device, shared_data_folder):
    """Creating a Buffer from a missing model file must raise RuntimeError."""
    network_file = os.path.join(shared_data_folder, "some_unknown_model.tflite")

    with pytest.raises(RuntimeError) as err:
        # Result intentionally discarded: the constructor is expected to raise.
        driver.Buffer(device, network_file)

    # Only check for part of the exception since the exception returns
    # absolute path which will change on different machines.
    assert 'Failed to open file:' in str(err.value)
57
58
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_swig_ownership(network_buffer):
    """Check that SWIG owns the Buffer proxy object.

    ``thisown`` being true means the Python wrapper owns the underlying C++
    Buffer, so it is released automatically when the wrapper is
    garbage-collected.
    """
    assert network_buffer.thisown
66
67
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_size(network_buffer):
    """A Buffer created from a model file must report a non-zero size."""
    buffer_size = network_buffer.size()
    assert buffer_size > 0
72
73
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_clear(network_buffer):
    """After ``clear()``, every element of the buffer data must read as zero."""
    network_buffer.clear()
    # The buffer is not modified after clear(), so fetch the data once
    # instead of calling data() again on every iteration.
    data = network_buffer.data()
    for i in range(network_buffer.size()):
        assert data[i] == 0, "data[{}] is {} after clear()".format(i, data[i])
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +010080
81
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_getFd(network_buffer):
    """The Buffer must expose a non-negative file descriptor."""
    fd = network_buffer.getFd()
    assert fd >= 0
86
87
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(device, network_buffer):
    """A Network built from the model buffer must report a non-zero IFM size."""
    net = driver.Network(device, network_buffer)
    ifm_size = net.getIfmSize()
    assert ifm_size > 0
    # Handing the buffer to Network must leave the Python wrapper still
    # reporting SWIG ownership of the buffer.
    assert network_buffer.thisown
94
95
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_network_buffer_none(device):
    """Creating a Network from a ``None`` buffer must raise RuntimeError."""
    with pytest.raises(RuntimeError) as err:
        driver.Network(device, None)

    # Match only the stable prefix of the message, in case the driver
    # appends extra detail to it.
    assert 'Failed to create the network' in str(err.value)
105
106
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ofm_size(device, network_buffer):
    """A Network built from the model buffer must report a non-zero OFM size."""
    net = driver.Network(device, network_buffer)
    ofm_size = net.getOfmSize()
    assert ofm_size > 0
112
113
def test_getMaxPmuEventCounters():
    """The driver must advertise at least one PMU event counter."""
    max_counters = driver.Inference.getMaxPmuEventCounters()
    assert max_counters > 0
116
117
@pytest.fixture()
def inf(device_name, model_name, input_files, timeout, shared_data_folder):
    """Run one inference on ``model_name`` and yield the completed Inference.

    Args:
        device_name: device name passed to ``driver.open_device`` (e.g. 'ethosu0').
        model_name: model file name inside ``shared_data_folder``.
        input_files: .npy file names inside ``shared_data_folder``; their data
            is passed to ``driver.populate_buffers`` together with the
            network's IFM buffers (presumably matched in order — confirm).
        timeout: value passed (as ``int``) to ``Inference.wait``; the tests use
            5000000000, presumably nanoseconds — TODO confirm.
        shared_data_folder: directory holding the model and input files.

    The inference is created with the cycle counter enabled and one PMU event
    counter requested, so tests can query ``getCycleCounter`` and
    ``getPmuCounters``.
    """
    # Prepare full paths of the model and input files
    full_path_model_file = os.path.join(shared_data_folder, model_name)
    full_path_input_files = [os.path.join(shared_data_folder, input_file)
                             for input_file in input_files]

    ifms_data = [read_npy_file_to_buf(ifm_file) for ifm_file in full_path_input_files]

    device = driver.open_device(device_name)
    device.ping()
    network = driver.load_model(device, full_path_model_file)
    ofms = driver.allocate_buffers(device, network.getOfmDims())
    ifms = driver.allocate_buffers(device, network.getIfmDims())

    driver.populate_buffers(ifms_data, ifms)
    # Request a single PMU event counter configured with event id 1
    # (NOTE(review): confirm which PMU event id 1 denotes).
    ethos_pmu_counters = [1]
    enable_cycle_counter = True
    inf_inst = driver.Inference(network, ifms, ofms, ethos_pmu_counters, enable_cycle_counter)
    inf_inst.wait(int(timeout))

    # device, network and the buffers stay referenced by this suspended
    # generator frame until the fixture is torn down.
    yield inf_inst
144
145
@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
                         [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
def test_inf_get_cycle_counter(inf):
    """A completed inference must report a non-negative cycle count."""
    cycles = inf.getCycleCounter()
    assert cycles >= 0
151
152
@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
                         [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
def test_inf_get_pmu_counters(inf):
    """A completed inference must return at least one PMU counter value."""
    pmu_values = inf.getPmuCounters()
    assert len(pmu_values) > 0
158
159
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_capabilities(device):
    """The capabilities report must have hwId, hwCfg and driver populated."""
    caps = device.capabilities()
    for attr_name in ('hwId', 'hwCfg', 'driver'):
        assert getattr(caps, attr_name)
165 assert cap.driver
Mikael Olsson308e7f12023-06-12 15:00:55 +0200166
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_kernel_driver_version(device):
    """Check the kernel driver version reported by the device.

    The reported version must not be 0.0.0. The Python API must expose the
    range of supported kernel driver major versions, and that range must be
    well-formed.
    """
    version = device.getDriverVersion()
    zero_version = [0, 0, 0]
    # Validate that a version was returned
    assert zero_version != [version.major, version.minor, version.patch]
    # Check that supported kernel driver major versions are available in the
    # Python API. Read the constants directly (AttributeError if missing)
    # rather than testing their truthiness, which would wrongly fail for a
    # legitimate value of 0.
    min_major = driver.MIN_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION
    max_major = driver.MAX_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION
    assert min_major <= max_major
175 assert driver.MIN_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION
176
def test_driver_library_version():
    """getLibraryVersion() must match the DRIVER_LIBRARY_VERSION_* constants."""
    lib_version = driver.getLibraryVersion()
    wanted = [driver.DRIVER_LIBRARY_VERSION_MAJOR,
              driver.DRIVER_LIBRARY_VERSION_MINOR,
              driver.DRIVER_LIBRARY_VERSION_PATCH]
    reported = [lib_version.major, lib_version.minor, lib_version.patch]
    # Validate that the expected version was returned
    assert wanted == reported
183 assert expected_version == [version.major, version.minor, version.patch]