| """Tests for tosa_result_checker.py.""" |
| # Copyright (c) 2021-2022, ARM Limited. |
| # SPDX-License-Identifier: Apache-2.0 |
| from pathlib import Path |
| |
| import checker.tosa_result_checker as trc |
| import numpy as np |
| import pytest |
| |
| |
| def _create_data_file(name, npy_data): |
| """Create numpy data file.""" |
| file = Path(__file__).parent / name |
| with open(file, "wb") as f: |
| np.save(f, npy_data) |
| return file |
| |
| |
| def _create_empty_file(name): |
| """Create numpy data file.""" |
| file = Path(__file__).parent / name |
| f = open(file, "wb") |
| f.close() |
| return file |
| |
| |
| def _delete_data_file(file: Path): |
| """Delete numpy data file.""" |
| file.unlink() |
| |
| |
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int8, trc.TestResult.PASS),
        (np.int16, trc.TestResult.PASS),
        (np.int32, trc.TestResult.PASS),
        (np.int64, trc.TestResult.PASS),
        (np.uint8, trc.TestResult.PASS),
        (np.uint16, trc.TestResult.PASS),
        (np.uint32, trc.TestResult.MISMATCH),
        (np.uint64, trc.TestResult.MISMATCH),
        (np.float16, trc.TestResult.PASS),
        (np.float32, trc.TestResult.PASS),
        (np.float64, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.PASS),
    ],
)
def test_supported_types(data_type, expected):
    """Check which data types are supported.

    Identical reference and result arrays of ``data_type`` are saved. The
    checker should PASS them for supported types and report MISMATCH for the
    types marked as such in the parametrization.
    """
    # Use initialised data. np.ndarray() leaves its memory uninitialised, so
    # values are arbitrary (e.g. float NaN bit patterns) and the comparison
    # could become non-deterministic.
    npy_data = np.ones(shape=(2, 3), dtype=data_type)

    # Save data as reference and result files to compare.
    reference_file = _create_data_file("reference.npy", npy_data)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove files created, even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
| |
| |
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int32, trc.TestResult.MISMATCH),
        (np.int64, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.MISMATCH),
    ],
)
def test_shape_mismatch(data_type, expected):
    """Check that mismatch shapes do not pass.

    The reference is a (3, 2) array of ones and the result is a (2, 3) array
    of ones. The values agree, but the shapes differ.
    """
    # Generate and save data as reference and result files to compare.
    reference_file = _create_data_file(
        "reference.npy", np.ones(shape=(3, 2), dtype=data_type)
    )
    result_file = _create_data_file(
        "result.npy", np.ones(shape=(2, 3), dtype=data_type)
    )
    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove files created, even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
| |
| |
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int32, trc.TestResult.MISMATCH),
        (np.int64, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.MISMATCH),
    ],
)
def test_results_mismatch(data_type, expected):
    """Check that different results do not pass.

    Reference (zeros) and result (ones) share shape (2, 3) and dtype, so
    only the values differ.
    """
    # Generate and save data as reference and result files to compare.
    reference_file = _create_data_file(
        "reference.npy", np.zeros(shape=(2, 3), dtype=data_type)
    )
    result_file = _create_data_file(
        "result.npy", np.ones(shape=(2, 3), dtype=data_type)
    )
    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove files created, even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
| |
| |
@pytest.mark.parametrize(
    "data_type1,data_type2,expected",
    [  # Sample pairs of supported types that differ from each other
        (np.int32, np.int64, trc.TestResult.MISMATCH),
        (bool, np.float32, trc.TestResult.MISMATCH),
    ],
)
def test_types_mismatch(data_type1, data_type2, expected):
    """Check that different types in results do not pass.

    Reference and result are (3, 2) arrays of ones that differ only in dtype.
    """
    # Generate and save data as reference and result files to compare.
    reference_file = _create_data_file(
        "reference.npy", np.ones(shape=(3, 2), dtype=data_type1)
    )
    result_file = _create_data_file(
        "result.npy", np.ones(shape=(3, 2), dtype=data_type2)
    )
    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove files created, even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
| |
| |
@pytest.mark.parametrize(
    "reference_exists,result_exists,expected",
    [
        (True, False, trc.TestResult.MISSING_FILE),
        (False, True, trc.TestResult.MISSING_FILE),
    ],
)
def test_missing_files(reference_exists, result_exists, expected):
    """Check that missing files are caught.

    Both files are created first. The one flagged as not existing is then
    deleted before the checker runs.
    """
    # Generate and save data (initialised, so the file contents are
    # deterministic).
    npy_data = np.ones(shape=(2, 3), dtype=bool)
    reference_file = _create_data_file("reference.npy", npy_data)
    result_file = _create_data_file("result.npy", npy_data)
    if not reference_exists:
        _delete_data_file(reference_file)
    if not result_exists:
        _delete_data_file(result_file)

    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Only the files that were left in place need removing; do it even
        # when the assertion fails.
        if reference_exists:
            _delete_data_file(reference_file)
        if result_exists:
            _delete_data_file(result_file)
| |
| |
@pytest.mark.parametrize(
    "reference_numpy,result_numpy,expected",
    [
        (True, False, trc.TestResult.INCORRECT_FORMAT),
        (False, True, trc.TestResult.INCORRECT_FORMAT),
    ],
)
def test_incorrect_format_files(reference_numpy, result_numpy, expected):
    """Check that incorrect format files are caught.

    Each side is either a valid ``.npy`` file or an empty file named
    ``empty.npy``, depending on the ``*_numpy`` flags.
    """
    # Generate and save data (initialised, so the file contents are
    # deterministic).
    npy_data = np.ones(shape=(2, 3), dtype=bool)
    reference_file = (
        _create_data_file("reference.npy", npy_data)
        if reference_numpy
        else _create_empty_file("empty.npy")
    )
    result_file = (
        _create_data_file("result.npy", npy_data)
        if result_numpy
        else _create_empty_file("empty.npy")
    )

    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Both sides resolve to the same "empty.npy" path when neither is a
        # numpy file. Deduplicate so that path is not unlinked twice, which
        # would raise FileNotFoundError.
        for created in {reference_file, result_file}:
            _delete_data_file(created)