Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 1 | # |
Mikael Olsson | 308e7f1 | 2023-06-12 15:00:55 +0200 | [diff] [blame] | 2 | # SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> |
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | import pytest |
| 6 | import os |
| 7 | import ethosu_driver as driver |
| 8 | from ethosu_driver.inference_runner import read_npy_file_to_buf |
| 9 | |
| 10 | |
@pytest.fixture()
def device(device_name):
    """Provide a ``driver.Device`` opened on ``/dev/<device_name>``.

    ``device_name`` is supplied by the requesting test (the tests in this
    module provide it through ``pytest.mark.parametrize``).
    """
    device_path = "/dev/{name}".format(name=device_name)
    opened_device = driver.Device(device_path)
    yield opened_device
| 15 | |
| 16 | |
@pytest.fixture()
def network_file(model_name, shared_data_folder):
    """Provide the path of ``model_name`` joined onto ``shared_data_folder``."""
    model_path = os.path.join(shared_data_folder, model_name)
    yield model_path
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 21 | |
@pytest.fixture()
def network(device, network_file):
    """Provide a ``driver.Network`` built from ``network_file`` on ``device``."""
    loaded_network = driver.Network(device, network_file)
    yield loaded_network
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 26 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_device_swig_ownership(device):
    """Check that SWIG has ownership of the device object.

    Ownership instructs SWIG to take ownership of the returned value, which
    allows it to be automatically garbage-collected when it is no longer
    in use.
    """
    assert device.thisown
| 33 | |
| 34 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_device_ping(device):
    """Call ``ping()`` on the device; the test passes if nothing is raised."""
    device.ping()
| 38 | |
| 39 | |
@pytest.mark.parametrize('device_name', ['blabla'])
def test_device_wrong_name(device_name):
    """Check that opening a device with an unknown name raises RuntimeError."""
    bad_device_path = "/dev/{name}".format(name=device_name)

    with pytest.raises(RuntimeError) as exc_info:
        driver.Device(bad_device_path)

    # The full message contains an absolute path that differs between
    # machines, so only the fixed part of it is checked.
    message = str(exc_info.value)
    assert 'Failed to open device' in message
| 47 | |
| 48 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_driver_network_from_bytearray(device, network_file):
    """Create a network from the in-memory contents of the model file."""
    with open(network_file, 'rb') as model_fp:
        model_bytes = model_fp.read()

    # Construction is the whole test: it must complete without raising.
    network = driver.Network(device, model_bytes)
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 56 | |
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 57 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_driver_network_from_empty_bytearray(device):
    """Check that creating a network from an empty bytearray raises RuntimeError."""
    empty_model = bytearray()

    with pytest.raises(RuntimeError) as exc_info:
        network = driver.Network(device, empty_model)

    message = str(exc_info.value)
    assert 'Failed to create the network, networkSize is zero' in message
| 64 | |
| 65 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_driver_network_from_file(device, network_file):
    """Create a network from the model file path."""
    # Construction is the whole test: it must complete without raising.
    network = driver.Network(device, network_file)
| 70 | |
| 71 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['some_unknown_model.tflite'])
def test_driver_network_filenotfound_exception(device, network_file):
    """Check that creating a network from a missing file raises RuntimeError."""
    with pytest.raises(RuntimeError) as exc_info:
        network = driver.Network(device, network_file)

    # The full message contains an absolute path that differs between
    # machines, so only the fixed part of it is checked.
    message = str(exc_info.value)
    assert 'Failed to open file:' in message
| 81 | |
| 82 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_swig_ownership(network):
    """Check that SWIG has ownership of the network object.

    Ownership instructs SWIG to take ownership of the returned value, which
    allows it to be automatically garbage-collected when it is no longer
    in use.
    """
    assert network.thisown
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 90 | |
| 91 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(device, network):
    """Check that the network reports an IFM size greater than zero."""
    ifm_size = network.getIfmSize()
    assert ifm_size > 0
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 96 | |
| 97 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ofm_size(device, network):
    """Check that the network reports an OFM size greater than zero."""
    ofm_size = network.getOfmSize()
    assert ofm_size > 0
| 102 | |
| 103 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_buffer_swig_ownership(device):
    """Check that SWIG has ownership of a newly created buffer object."""
    new_buffer = driver.Buffer(device, 1024)
    assert new_buffer.thisown
| 108 | |
| 109 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_buffer_getFd(device):
    """Check that ``getFd()`` of a new buffer returns a non-negative value."""
    new_buffer = driver.Buffer(device, 1024)
    fd = new_buffer.getFd()
    assert fd >= 0
| 114 | |
| 115 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_buffer_size(device):
    """Check that a buffer created with size 1024 reports ``size()`` of 1024."""
    new_buffer = driver.Buffer(device, 1024)
    reported_size = new_buffer.size()
    assert reported_size == 1024
| 120 | |
| 121 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_clear(device, network_file):
    """Check that every element of a buffer reads as zero after ``clear()``."""
    # NOTE(review): the sibling buffer tests pass an integer size as the
    # second argument; here it is the model file path - confirm this
    # constructor form is intended.
    buf = driver.Buffer(device, network_file)
    buf.clear()

    element_count = buf.size()
    for offset in range(element_count):
        assert buf.data()[offset] == 0
| 130 | |
| 131 | |
def test_getMaxPmuEventCounters():
    """Check that the maximum number of PMU event counters is positive."""
    max_counters = driver.Inference.getMaxPmuEventCounters()
    assert max_counters > 0
| 134 | |
| 135 | |
@pytest.fixture()
def inf(device_name, model_name, input_files, timeout, shared_data_folder):
    """Run one inference and provide the resulting ``driver.Inference``.

    The model and input files are looked up inside ``shared_data_folder``.
    The inputs are read with ``read_npy_file_to_buf`` and copied into the
    IFM buffers, and the inference is created with PMU counter ``[1]`` and
    the cycle counter enabled. ``wait()`` is called before the object is
    handed to the test.

    Args:
        device_name: Name passed to ``driver.open_device``.
        model_name: Model file name, relative to ``shared_data_folder``.
        input_files: Iterable of input file names, relative to
            ``shared_data_folder``.
        timeout: Converted with ``int()`` and passed to ``Inference.wait``.
        shared_data_folder: Directory containing the model and input files.

    Yields:
        The ``driver.Inference`` instance after ``wait()`` has returned.
    """
    # Prepare full paths of the model and the inputs
    full_path_model_file = os.path.join(shared_data_folder, model_name)
    full_path_input_files = [os.path.join(shared_data_folder, input_file)
                             for input_file in input_files]

    ifms_data = [read_npy_file_to_buf(ifm_file)
                 for ifm_file in full_path_input_files]

    device = driver.open_device(device_name)
    device.ping()
    network = driver.load_model(device, full_path_model_file)
    ofms = driver.allocate_buffers(device, network.getOfmDims())
    ifms = driver.allocate_buffers(device, network.getIfmDims())

    driver.populate_buffers(ifms_data, ifms)
    ethos_pmu_counters = [1]
    enable_cycle_counter = True
    inf_inst = driver.Inference(network, ifms, ofms, ethos_pmu_counters, enable_cycle_counter)
    # NOTE(review): the result of wait() is not checked - confirm whether a
    # timed-out inference should fail the fixture.
    inf_inst.wait(int(timeout))

    yield inf_inst
| 162 | |
| 163 | |
@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
                         [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
def test_inf_get_cycle_counter(inf):
    """Check that the inference reports a non-negative cycle counter."""
    cycle_count = inf.getCycleCounter()
    assert cycle_count >= 0
| 169 | |
| 170 | |
@pytest.mark.parametrize('device_name, model_name, timeout, input_files',
                         [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])])
def test_inf_get_pmu_counters(inf):
    """Check that the inference returns at least one PMU counter."""
    pmu_counters = inf.getPmuCounters()
    assert len(pmu_counters) > 0
| 176 | |
| 177 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_capabilities(device):
    """Check that the device capabilities expose truthy hwId, hwCfg and driver."""
    capabilities = device.capabilities()
    assert capabilities.hwId
    assert capabilities.hwCfg
    assert capabilities.driver
Mikael Olsson | 308e7f1 | 2023-06-12 15:00:55 +0200 | [diff] [blame] | 184 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_kernel_driver_version(device):
    """Check the kernel driver version and the supported-version constants."""
    driver_version = device.getDriverVersion()
    unset_version = [0, 0, 0]

    # An all-zero version is treated as "no version was returned"
    assert unset_version != [driver_version.major,
                             driver_version.minor,
                             driver_version.patch]

    # The supported kernel driver major versions must be available in the
    # Python API
    assert driver.MAX_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION
    assert driver.MIN_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION
| 194 | |
def test_driver_library_version():
    """Check that ``getLibraryVersion()`` matches the exported version constants."""
    library_version = driver.getLibraryVersion()
    reported = [library_version.major, library_version.minor, library_version.patch]

    expected = [
        driver.DRIVER_LIBRARY_VERSION_MAJOR,
        driver.DRIVER_LIBRARY_VERSION_MINOR,
        driver.DRIVER_LIBRARY_VERSION_PATCH,
    ]
    # The reported version must equal the version constants
    assert expected == reported