blob: fc8e92113fe9af0718883100cb1932f75dd5ad7e [file] [log] [blame]
Kshitij Sisodiaf9efe0d2022-09-30 16:42:50 +01001#
2# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
3# SPDX-License-Identifier: Apache-2.0
4#
5import pytest
6import os
7import ethosu_driver as driver
8from ethosu_driver.inference_runner import read_npy_file_to_buf
9
10
@pytest.fixture()
def device(device_name):
    """Open the device called ``device_name`` and hand it to the test.

    ``device_name`` is normally supplied by ``@pytest.mark.parametrize`` on
    the requesting test.
    """
    # NOTE(review): there is no teardown after the yield, so the device is
    # never explicitly closed here — confirm whether that is intended.
    yield driver.open_device(device_name)
15
16
@pytest.fixture()
def network(device, model_name, shared_data_folder):
    """Load ``model_name`` from ``shared_data_folder`` onto ``device``.

    Yields whatever ``driver.load_model`` returns for that file.
    """
    model_path = os.path.join(shared_data_folder, model_name)
    yield driver.load_model(device, model_path)
22
23
@pytest.mark.parametrize('device_name', ['blabla'])
def test_open_device_wrong_name(device_name):
    """Opening a device under a bogus name must raise ``RuntimeError``."""
    with pytest.raises(RuntimeError) as err:
        # The return value is irrelevant: the call is expected to raise.
        driver.open_device(device_name)
    # Only check a fixed part of the message, since the full text may
    # contain machine-specific details.
    assert 'Failed to open device' in str(err.value)
31
32
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_network_filenotfound_exception(device, shared_data_folder):
    """Loading a model from a non-existent file must raise ``RuntimeError``."""
    missing_model = os.path.join(shared_data_folder, "some_unknown_model.tflite")

    with pytest.raises(RuntimeError) as excinfo:
        driver.load_model(device, missing_model)

    # The message carries an absolute path that differs between machines,
    # so only its stable prefix is compared.
    message = str(excinfo.value)
    assert 'Failed to open file:' in message
44
45
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(network):
    """The loaded network must report a strictly positive IFM size."""
    ifm_size = network.getIfmSize()
    assert ifm_size > 0
50
51
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_allocate_buffers(device):
    """Freshly allocated buffers are empty and have the requested capacity."""
    requested = [128, 256]
    buffers = driver.allocate_buffers(device, requested)
    assert len(buffers) == len(requested)
    for buf, expected_capacity in zip(buffers, requested):
        # Nothing has been written yet, so the used size must be zero.
        assert buf.size() == 0
        assert buf.capacity() == expected_capacity
60
61
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
@pytest.mark.parametrize('ifms_file_list', [['model_ifm.npy']])
def test_set_ifm_buffers(device, network, ifms_file_list, shared_data_folder):
    """Fill buffers allocated for the network's IFMs with data from .npy files.

    Each name in ``ifms_file_list`` is resolved against ``shared_data_folder``
    and loaded with ``read_npy_file_to_buf``. Buffers sized from
    ``network.getIfmDims()`` are then populated with that data.
    """
    ifms_data = [
        read_npy_file_to_buf(os.path.join(shared_data_folder, file_name))
        for file_name in ifms_file_list
    ]

    ifms = driver.allocate_buffers(device, network.getIfmDims())
    driver.populate_buffers(ifms_data, ifms)
    assert len(ifms) > 0
77