Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 1 | # |
Mikael Olsson | 0754515 | 2023-10-17 13:05:38 +0200 | [diff] [blame] | 2 | # SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> |
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | import pytest |
| 6 | import os |
| 7 | import ethosu_driver as driver |
| 8 | from ethosu_driver.inference_runner import read_npy_file_to_buf |
| 9 | |
| 10 | |
@pytest.fixture()
def device(device_name):
    """Open the Ethos-U device called ``device_name`` and hand it to the test.

    ``device_name`` is supplied by the requesting test's parametrization
    (e.g. ``@pytest.mark.parametrize('device_name', ['ethosu0'])``).
    """
    opened = driver.open_device(device_name)
    yield opened
| 15 | |
| 16 | |
@pytest.fixture()
def network(device, model_name, shared_data_folder):
    """Load the model file ``model_name`` from ``shared_data_folder`` onto ``device``.

    ``model_name`` comes from the requesting test's parametrization;
    ``shared_data_folder`` is provided by a fixture defined elsewhere
    (not visible in this file).
    """
    model_path = os.path.join(shared_data_folder, model_name)
    loaded = driver.load_model(device, model_path)
    yield loaded
| 22 | |
| 23 | |
@pytest.mark.parametrize('device_name', ['blabla'])
def test_open_device_wrong_name(device_name):
    """Opening a device name that does not exist must raise RuntimeError."""
    with pytest.raises(RuntimeError) as err:
        # No result is bound: the call is expected to raise before returning.
        driver.open_device(device_name)
    # Only check for part of the exception message, since the full text may
    # contain details (e.g. a path) that change on different machines.
    assert 'Failed to open device' in str(err.value)
| 31 | |
| 32 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_network_filenotfound_exception(device, shared_data_folder):
    """Loading a model file that does not exist must raise RuntimeError."""
    missing_model = os.path.join(shared_data_folder, "some_unknown_model.tflite")

    with pytest.raises(RuntimeError) as excinfo:
        driver.load_model(device, missing_model)

    # The message embeds an absolute path that differs between machines,
    # so only its fixed prefix is checked.
    message = str(excinfo.value)
    assert 'Failed to open file:' in message
| 44 | |
| 45 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(network):
    """A loaded network must report a positive IFM size."""
    ifm_size = network.getIfmSize()
    assert ifm_size > 0
| 50 | |
| 51 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_allocate_buffers(device):
    """allocate_buffers returns one buffer per requested size, each of that size."""
    requested_sizes = [128, 256]
    buffers = driver.allocate_buffers(device, requested_sizes)
    assert len(buffers) == 2
    # Check each buffer against the size it was requested with, in order.
    for buf, expected_size in zip(buffers, requested_sizes):
        assert buf.size() == expected_size
Kshitij Sisodia | f9efe0d | 2022-09-30 16:42:50 +0100 | [diff] [blame] | 58 | |
| 59 | |
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
@pytest.mark.parametrize('ifms_file_list', [['model_ifm.npy']])
def test_set_ifm_buffers(device, network, ifms_file_list, shared_data_folder):
    """Allocate IFM buffers for ``network`` and populate them from ``.npy`` files.

    ``ifms_file_list`` holds input file names relative to ``shared_data_folder``.
    """
    ifms_data = [
        read_npy_file_to_buf(os.path.join(shared_data_folder, input_file))
        for input_file in ifms_file_list
    ]

    ifm_dims = network.getIfmDims()
    ifms = driver.allocate_buffers(device, ifm_dims)
    driver.populate_buffers(ifms_data, ifms)

    assert len(ifms) > 0
    # allocate_buffers yields one buffer per requested size (see
    # test_allocate_buffers), so every IFM of the network must have a buffer.
    assert len(ifms) == len(ifm_dims)
| 75 | |