blob: e9cb5c8ae3543aba9d1d9fe57687f88bfe6190e9 [file] [log] [blame]
#
# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
5import pytest
6import os
7import ethosu_driver as driver
8from ethosu_driver.inference_runner import read_npy_file_to_buf
9
10
@pytest.fixture()
def device(device_name):
    """Yield a ``driver.Device`` opened on the node ``/dev/<device_name>``."""
    node_path = f"/dev/{device_name}"
    opened = driver.Device(node_path)
    yield opened
15
16
@pytest.fixture()
def network_file(model_name, shared_data_folder):
    """Yield the path of ``model_name`` inside ``shared_data_folder``.

    The path is only joined, never checked for existence, so tests may also
    request paths to files that do not exist.
    """
    yield os.path.join(shared_data_folder, model_name)
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +010021
@pytest.fixture()
def network(device, network_file):
    """Yield a ``driver.Network`` loaded on ``device`` from the model at ``network_file``."""
    yield driver.Network(device, network_file)
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +010026
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_device_swig_ownership(device):
    """The Device proxy must own its underlying native object.

    With ``thisown`` set, SWIG takes ownership of the returned value, so it is
    garbage-collected automatically once it is no longer in use.
    """
    owns_native_object = device.thisown
    assert owns_native_object
33
34
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_device_ping(device):
    """Pinging an opened device must complete without raising."""
    device.ping()
38
39
@pytest.mark.parametrize('device_name', ['blabla'])
def test_device_wrong_name(device_name):
    """Opening a non-existent device node must raise ``RuntimeError``."""
    bogus_path = "/dev/{}".format(device_name)
    # Match only part of the message: the full message contains an absolute
    # path, which changes from machine to machine.
    with pytest.raises(RuntimeError, match='Failed to open device'):
        driver.Device(bogus_path)
47
48
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_driver_network_from_bytearray(device, network_file):
    """A Network can be created from the raw bytes of a model file."""
    with open(network_file, 'rb') as model:
        model_bytes = model.read()
    network = driver.Network(device, model_bytes)
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +010056
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +010057
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_driver_network_from_empty_bytearray(device):
    """Creating a Network from zero bytes of model data must raise ``RuntimeError``."""
    expected_message = 'Failed to create the network, networkSize is zero'
    with pytest.raises(RuntimeError, match=expected_message):
        driver.Network(device, bytearray())
64
65
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_driver_network_from_file(device, network_file):
    """A Network can be created directly from a model file path."""
    loaded = driver.Network(device, network_file)
70
71
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['some_unknown_model.tflite'])
def test_driver_network_filenotfound_exception(device, network_file):
    """Creating a Network from a missing model file must raise ``RuntimeError``."""
    # Match only part of the message: the full message contains an absolute
    # path, which changes from machine to machine.
    with pytest.raises(RuntimeError, match='Failed to open file:'):
        driver.Network(device, network_file)
81
82
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_swig_ownership(network):
    """The Network proxy must own its underlying native object.

    With ``thisown`` set, SWIG takes ownership of the returned value, so it is
    garbage-collected automatically once it is no longer in use.
    """
    owns_native_object = network.thisown
    assert owns_native_object
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +010090
91
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(device, network):
    """A loaded network must report a positive IFM size."""
    ifm_size = network.getIfmSize()
    assert ifm_size > 0
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +010096
97
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ofm_size(device, network):
    """A loaded network must report a positive OFM size."""
    ofm_size = network.getOfmSize()
    assert ofm_size > 0
102
103
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_buffer_swig_ownership(device):
    """A freshly allocated 1024-byte Buffer proxy must own its native object."""
    assert driver.Buffer(device, 1024).thisown
108
109
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_buffer_getFd(device):
    """A 1024-byte Buffer must expose a non-negative file descriptor."""
    buf = driver.Buffer(device, 1024)
    fd = buf.getFd()
    assert fd >= 0
114
115
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_buffer_size(device):
    """A Buffer allocated with 1024 bytes must report a size of 1024."""
    requested = 1024
    buf = driver.Buffer(device, requested)
    assert buf.size() == requested
120
121
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_clear(device, network_file):
    """After ``clear()`` every byte of the buffer must read back as zero.

    The buffer is constructed with ``network_file`` (the model file path),
    presumably so that it starts out holding non-zero content.
    NOTE(review): confirm which ``driver.Buffer`` constructor this selects.
    """
    buffer = driver.Buffer(device, network_file)

    buffer.clear()
    # Fetch the data view once instead of calling buffer.data() (a call into
    # the SWIG bindings) again for every single byte inside the loop.
    data = buffer.data()
    for index in range(buffer.size()):
        assert data[index] == 0, "byte {} was not cleared".format(index)
130
131
def test_getMaxPmuEventCounters():
    """The driver must report support for at least one PMU event counter."""
    max_counters = driver.Inference.getMaxPmuEventCounters()
    assert max_counters > 0
134
135
@pytest.fixture()
def inf(device_name, model_name, input_files, timeout, shared_data_folder):
    """Run one inference and yield the completed ``driver.Inference``.

    The model and every input file are resolved relative to
    ``shared_data_folder``; each input file is read with
    ``read_npy_file_to_buf`` and used to populate the network's IFM buffers.
    The inference is created with PMU event counter list ``[1]`` and the cycle
    counter enabled, then waited on (``wait(int(timeout))``) before it is
    handed to the test.
    """
    # Prepare full paths of the model and its inputs.
    full_path_model_file = os.path.join(shared_data_folder, model_name)
    full_path_input_files = [os.path.join(shared_data_folder, input_file)
                             for input_file in input_files]
    ifms_data = [read_npy_file_to_buf(ifm_file)
                 for ifm_file in full_path_input_files]

    device = driver.open_device(device_name)
    device.ping()
    network = driver.load_model(device, full_path_model_file)
    ofms = driver.allocate_buffers(device, network.getOfmDims())
    ifms = driver.allocate_buffers(device, network.getIfmDims())

    # Fill the IFM buffers with the input data read above.
    driver.populate_buffers(ifms_data, ifms)
    # NOTE(review): document which PMU event id 1 selects.
    ethos_pmu_counters = [1]
    enable_cycle_counter = True
    inf_inst = driver.Inference(network, ifms, ofms, ethos_pmu_counters, enable_cycle_counter)
    # The parametrized timeout is 5000000000, presumably nanoseconds — confirm.
    inf_inst.wait(int(timeout))

    yield inf_inst
162
163
@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
                         [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
def test_inf_get_cycle_counter(inf):
    """A completed inference must report a non-negative cycle count."""
    assert inf.getCycleCounter() >= 0
169
170
@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
                         [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
def test_inf_get_pmu_counters(inf):
    """A completed inference must report at least one PMU counter value."""
    counters = inf.getPmuCounters()
    counter_count = len(counters)
    assert counter_count > 0
176
177
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_capabilities(device):
    """The device capabilities must have truthy hwId, hwCfg and driver fields."""
    capabilities = device.capabilities()
    for field_name in ('hwId', 'hwCfg', 'driver'):
        assert getattr(capabilities, field_name)
Mikael Olsson308e7f12023-06-12 15:00:55 +0200184
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_kernel_driver_version(device):
    """The kernel driver must report a version, and the API must define supported majors."""
    version = device.getDriverVersion()
    reported = (version.major, version.minor, version.patch)
    # An all-zero version is treated as "no version was returned".
    assert reported != (0, 0, 0)
    # The Python API must expose the supported kernel driver major-version bounds.
    assert driver.MAX_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION
    assert driver.MIN_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION
194
def test_driver_library_version():
    """``getLibraryVersion()`` must match the ``DRIVER_LIBRARY_VERSION_*`` constants."""
    version = driver.getLibraryVersion()
    actual = (version.major, version.minor, version.patch)
    expected = (driver.DRIVER_LIBRARY_VERSION_MAJOR,
                driver.DRIVER_LIBRARY_VERSION_MINOR,
                driver.DRIVER_LIBRARY_VERSION_PATCH)
    assert actual == expected