#
# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
import pytest
import os
import ethosu_driver as driver
from ethosu_driver.inference_runner import read_npy_file_to_buf


@pytest.fixture()
def device(device_name):
    """Open the driver device node ``/dev/<device_name>`` for a test.

    ``device_name`` is supplied by the requesting test, e.g. through
    ``pytest.mark.parametrize``.
    """
    device_path = "/dev/{}".format(device_name)
    opened_device = driver.Device(device_path)
    yield opened_device


@pytest.fixture()
def network_buffer(device, model_name, shared_data_folder):
    """Build a ``driver.Buffer`` for the model file ``model_name``.

    The model file path is formed by joining ``shared_data_folder`` and
    ``model_name``.
    """
    model_path = os.path.join(shared_data_folder, model_name)
    model_buffer = driver.Buffer(device, model_path)
    yield model_buffer


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_device_swig_ownership(device):
    """The Python wrapper must own the underlying device object."""
    # ``thisown`` is the SWIG ownership flag. When it is set, the wrapped
    # object is released automatically once the Python proxy is no longer
    # in use and gets garbage-collected.
    owned_by_python = device.thisown
    assert owned_by_python


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_device_ping(device):
    """Smoke test: passes as long as ``device.ping()`` does not raise."""
    device.ping()


@pytest.mark.parametrize('device_name', ['blabla'])
def test_device_wrong_name(device_name):
    """Opening an unknown device name must raise ``RuntimeError``."""
    unknown_device_path = "/dev/{}".format(device_name)
    with pytest.raises(RuntimeError) as exc_info:
        driver.Device(unknown_device_path)

    # Match only a fragment of the message: the full text contains an
    # absolute path, which differs from machine to machine.
    error_text = str(exc_info.value)
    assert 'Failed to open device' in error_text


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_driver_network_filenotfound_exception(device, shared_data_folder):
    """Building a ``driver.Buffer`` from a missing file raises ``RuntimeError``."""
    network_file = os.path.join(shared_data_folder, "some_unknown_model.tflite")

    with pytest.raises(RuntimeError) as err:
        # The constructor is expected to raise, so its result is deliberately
        # not bound to a name.
        driver.Buffer(device, network_file)

    # Only check for part of the exception since the exception returns
    # absolute path which will change on different machines.
    assert 'Failed to open file:' in str(err.value)


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_swig_ownership(network_buffer):
    """The Python wrapper must own the underlying buffer object."""
    # ``thisown`` is the SWIG ownership flag. When it is set, the wrapped
    # object is released automatically once the Python proxy is no longer
    # in use and gets garbage-collected.
    owned_by_python = network_buffer.thisown
    assert owned_by_python


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_capacity(network_buffer):
    """The model buffer must report a positive capacity."""
    capacity = network_buffer.capacity()
    assert capacity > 0


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_size(network_buffer):
    """The model buffer must report a positive size."""
    current_size = network_buffer.size()
    assert current_size > 0


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_clear(network_buffer):
    """After ``clear()`` the buffer must report a size of zero."""
    network_buffer.clear()
    size_after_clear = network_buffer.size()
    assert size_after_clear == 0


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_resize(network_buffer):
    """After ``resize()`` the buffer must report the requested size."""
    data_offset = 1
    # Pick the size so that size plus offset equals the buffer capacity.
    requested_size = network_buffer.capacity() - data_offset
    network_buffer.resize(requested_size, data_offset)
    assert network_buffer.size() == requested_size


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_buffer_getFd(network_buffer):
    """``getFd()`` must return a non-negative value."""
    fd = network_buffer.getFd()
    assert fd >= 0


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ifm_size(device, network_buffer):
    """A network built from the model buffer reports a positive IFM size."""
    net = driver.Network(device, network_buffer)
    ifm_size = net.getIfmSize()
    assert ifm_size > 0
    # The buffer wrapper must still carry the SWIG ownership flag after it
    # has been handed to the network.
    assert network_buffer.thisown


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_check_network_buffer_none(device):
    """Creating a ``driver.Network`` with ``None`` as buffer raises ``RuntimeError``."""
    with pytest.raises(RuntimeError) as err:
        driver.Network(device, None)

    # Match only a fragment of the message so the test does not depend on
    # the exact wording of the rest of it.
    assert 'Failed to create the network' in str(err.value)


@pytest.mark.parametrize('device_name', ['ethosu0'])
@pytest.mark.parametrize('model_name', ['model.tflite'])
def test_check_network_ofm_size(device, network_buffer):
    """A network built from the model buffer reports a positive OFM size."""
    net = driver.Network(device, network_buffer)
    ofm_size = net.getOfmSize()
    assert ofm_size > 0


def test_getMaxPmuEventCounters():
    """``Inference.getMaxPmuEventCounters()`` must return a positive value."""
    max_counters = driver.Inference.getMaxPmuEventCounters()
    assert max_counters > 0


@pytest.fixture()
def inf(device_name, model_name, input_files, timeout, shared_data_folder):
    """Start one inference, wait on it, and yield the ``driver.Inference``.

    The model and every entry of ``input_files`` are looked up inside
    ``shared_data_folder``. The inference is created with the PMU counter
    list ``[1]`` and with the cycle counter enabled.
    """
    # Resolve the model and input paths, then read every input .npy file
    # before the device is opened.
    model_path = os.path.join(shared_data_folder, model_name)
    input_paths = [
        os.path.join(shared_data_folder, file_name) for file_name in input_files
    ]
    ifm_payloads = [read_npy_file_to_buf(path) for path in input_paths]

    npu = driver.open_device(device_name)
    npu.ping()
    net = driver.load_model(npu, model_path)

    # Allocate output buffers first, then input buffers, sized from the
    # dimensions reported by the network.
    ofm_buffers = driver.allocate_buffers(npu, net.getOfmDims())
    ifm_buffers = driver.allocate_buffers(npu, net.getIfmDims())
    driver.populate_buffers(ifm_payloads, ifm_buffers)

    pmu_counters = [1]
    cycle_counter_enabled = True
    inference = driver.Inference(
        net, ifm_buffers, ofm_buffers, pmu_counters, cycle_counter_enabled
    )
    inference.wait(int(timeout))

    yield inference


@pytest.mark.parametrize(
    'device_name, model_name, timeout, input_files',
    [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])],
)
def test_inf_get_cycle_counter(inf):
    """The inference from the ``inf`` fixture reports a non-negative cycle count."""
    assert inf.getCycleCounter() >= 0


@pytest.mark.parametrize(
    'device_name, model_name, timeout, input_files',
    [('ethosu0', 'model.tflite', 5000000000, ['model_ifm.npy'])],
)
def test_inf_get_pmu_counters(inf):
    """The inference from the ``inf`` fixture returns at least one PMU counter."""
    pmu_values = inf.getPmuCounters()
    assert len(pmu_values) > 0


@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_capabilities(device):
    """The capabilities object exposes truthy ``hwId``, ``hwCfg`` and ``driver``."""
    capabilities = device.capabilities()

    # Each field is checked on its own line so a failure names the field.
    assert capabilities.hwId
    assert capabilities.hwCfg
    assert capabilities.driver

@pytest.mark.parametrize('device_name', ['ethosu0'])
def test_kernel_driver_version(device):
    """The kernel driver reports a version other than 0.0.0."""
    reported = device.getDriverVersion()
    reported_triplet = [reported.major, reported.minor, reported.patch]

    # An all-zero triplet is treated as "no version returned".
    assert [0, 0, 0] != reported_triplet

    # The supported kernel driver major version bounds must be exposed by
    # the Python API.
    assert driver.MAX_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION
    assert driver.MIN_SUPPORTED_KERNEL_DRIVER_MAJOR_VERSION

