# blob: fe44b0ecf5aebde503abb77c9a71dd93dd893c95 [file] [log] [blame]
#
# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
import pytest
import os
import ethosu_driver as driver
from ethosu_driver.inference_runner import read_npy_file_to_buf
@pytest.fixture()
def device(device_name):
    """Open the Ethos-U device called ``device_name`` and hand it to the test.

    ``device_name`` is supplied by a ``parametrize`` mark on the requesting
    test. The handle is yielded without any explicit teardown.
    """
    opened_device = driver.open_device(device_name)
    yield opened_device
@pytest.fixture()
def network(device, model_name, shared_data_folder):
    """Load the model file ``model_name`` onto ``device``.

    The model path is resolved relative to ``shared_data_folder``, which is
    presumably provided by a conftest fixture (not visible here).
    ``model_name`` comes from a ``parametrize`` mark on the requesting test.
    """
    loaded_network = driver.load_model(
        device, os.path.join(shared_data_folder, model_name))
    yield loaded_network
@pytest.mark.parametrize('device_name', ['blabla'])
def test_open_device_wrong_name(device_name):
    """Opening a non-existent device must raise ``RuntimeError``."""
    # Only check for part of the exception message since it contains an
    # absolute path which will change on different machines. The pattern
    # is a plain literal, so ``match`` (re.search) behaves like a substring
    # check. The return value is deliberately not bound: it is never used,
    # and binding it to ``device`` would shadow the fixture of that name.
    with pytest.raises(RuntimeError, match='Failed to open device'):
        driver.open_device(device_name)
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_network_filenotfound_exception(device, shared_data_folder):
    """Loading a model file that does not exist must raise ``RuntimeError``."""
    missing_model = os.path.join(shared_data_folder, "some_unknown_model.tflite")
    with pytest.raises(RuntimeError) as excinfo:
        driver.load_model(device, missing_model)
    # The full message embeds an absolute path that differs between
    # machines, so only its stable prefix is checked.
    message = str(excinfo.value)
    assert 'Failed to open file:' in message
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(network):
    """The loaded network must report a positive IFM size."""
    ifm_size = network.getIfmSize()
    assert ifm_size > 0
@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_allocate_buffers(device):
    """One buffer is allocated per requested size, each of exactly that size."""
    requested_sizes = [128, 256]
    buffers = driver.allocate_buffers(device, requested_sizes)
    assert len(buffers) == len(requested_sizes)
    for buffer, expected_size in zip(buffers, requested_sizes):
        assert buffer.size() == expected_size
@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
@pytest.mark.parametrize('ifms_file_list', [['model_ifm.npy']])
def test_set_ifm_buffers(device, network, ifms_file_list, shared_data_folder):
    """Allocate IFM buffers for ``network`` and populate them from .npy files.

    Each entry of ``ifms_file_list`` is resolved relative to
    ``shared_data_folder`` and read with ``read_npy_file_to_buf``. The
    resulting data is copied into buffers sized by ``network.getIfmDims()``.
    """
    ifms_data = [
        read_npy_file_to_buf(os.path.join(shared_data_folder, ifm_file))
        for ifm_file in ifms_file_list
    ]
    ifm_dims = network.getIfmDims()
    ifms = driver.allocate_buffers(device, ifm_dims)
    driver.populate_buffers(ifms_data, ifms)
    assert len(ifms) > 0
    # allocate_buffers returns one buffer per requested size (as asserted in
    # test_allocate_buffers), so there must be one IFM buffer per IFM dim.
    assert len(ifms) == len(ifm_dims)