blob: fe44b0ecf5aebde503abb77c9a71dd93dd893c95 [file] [log] [blame]
#
# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
5import pytest
6import os
7import ethosu_driver as driver
8from ethosu_driver.inference_runner import read_npy_file_to_buf
9
10
@pytest.fixture()
def device(device_name):
    """Open the device identified by ``device_name`` and hand it to the test.

    ``device_name`` is supplied by the requesting test (the tests in this
    module provide it through ``pytest.mark.parametrize``).
    """
    opened_device = driver.open_device(device_name)
    yield opened_device
15
16
@pytest.fixture()
def network(device, model_name, shared_data_folder):
    """Load the model file ``model_name`` found in ``shared_data_folder``.

    The model is loaded with ``driver.load_model`` using the ``device``
    fixture, and the resulting network object is handed to the test.
    """
    model_path = os.path.join(shared_data_folder, model_name)
    loaded_network = driver.load_model(device, model_path)
    yield loaded_network
22
23
@pytest.mark.parametrize('device_name', ['blabla'])
def test_open_device_wrong_name(device_name):
    """Opening a device under an unknown name must raise ``RuntimeError``."""
    with pytest.raises(RuntimeError) as err:
        # The call is expected to raise, so its return value is not bound.
        driver.open_device(device_name)

    # Checked after the ``with`` block: a statement placed after the raising
    # call inside the block would never execute.
    # Only check for part of the exception since the exception returns
    # absolute path which will change on different machines.
    assert 'Failed to open device' in str(err.value)
31
32
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_network_filenotfound_exception(device, shared_data_folder):
    """Loading a model path that is expected not to exist must raise ``RuntimeError``."""
    missing_model = os.path.join(shared_data_folder, "some_unknown_model.tflite")

    with pytest.raises(RuntimeError) as exc_info:
        driver.load_model(device, missing_model)

    # The full message embeds an absolute path that differs between machines,
    # so only the fixed part of the text is matched.
    message = str(exc_info.value)
    assert 'Failed to open file:' in message
44
45
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(network):
    """The loaded network must report an IFM size greater than zero."""
    ifm_size = network.getIfmSize()
    assert ifm_size > 0
50
51
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_allocate_buffers(device):
    """``allocate_buffers`` must return one buffer per requested size."""
    expected_sizes = (128, 256)

    # A fresh list is passed so the expected values above stay untouched.
    allocated = driver.allocate_buffers(device, list(expected_sizes))

    assert len(allocated) == 2
    for index, expected in enumerate(expected_sizes):
        assert allocated[index].size() == expected
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +010058
59
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
@pytest.mark.parametrize('ifms_file_list', [['model_ifm.npy']])
def test_set_ifm_buffers(device, network, ifms_file_list, shared_data_folder):
    """Allocate buffers from ``network.getIfmDims()`` and populate them.

    The data used to populate the buffers is read from the files named in
    ``ifms_file_list``, resolved relative to ``shared_data_folder``.
    """
    input_paths = [
        os.path.join(shared_data_folder, file_name)
        for file_name in ifms_file_list
    ]
    input_payloads = [read_npy_file_to_buf(path) for path in input_paths]

    ifm_buffers = driver.allocate_buffers(device, network.getIfmDims())
    driver.populate_buffers(input_payloads, ifm_buffers)
    assert len(ifm_buffers) > 0
75