Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 1 | """Tests for tosa_result_checker.py.""" |
| 2 | # Copyright (c) 2021-2022, ARM Limited. |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | from pathlib import Path |
| 5 | |
Jeremy Johnson | 5c1364c | 2022-01-13 15:04:21 +0000 | [diff] [blame] | 6 | import checker.tosa_result_checker as trc |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 7 | import numpy as np |
| 8 | import pytest |
| 9 | |
Jeremy Johnson | be1a940 | 2021-12-15 17:14:56 +0000 | [diff] [blame] | 10 | |
def _create_data_file(name, npy_data):
    """Save *npy_data* in ``.npy`` format as *name* next to this test module.

    Returns the ``Path`` of the file that was written.
    """
    target = Path(__file__).parent.joinpath(name)
    # Passing an open handle (rather than a path) to np.save writes to exactly
    # ``target`` without any suffix handling.
    with target.open("wb") as handle:
        np.save(handle, npy_data)
    return target
| 17 | |
| 18 | |
def _create_empty_file(name):
    """Create an empty (zero-byte) file called *name* next to this test module.

    The result is deliberately not valid numpy data. Any existing file with
    the same name is truncated.

    Returns the ``Path`` of the created file.
    """
    file = Path(__file__).parent / name
    # write_bytes opens, truncates and closes the file in one call, so no
    # file handle is left to manage.
    file.write_bytes(b"")
    return file
| 25 | |
| 26 | |
def _delete_data_file(file: Path):
    """Remove *file* from disk.

    ``Path.unlink`` is called with its default behaviour, so a missing file
    raises ``FileNotFoundError``.
    """
    target = file
    target.unlink()
| 30 | |
| 31 | |
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int8, trc.TestResult.PASS),
        (np.int16, trc.TestResult.PASS),
        (np.int32, trc.TestResult.PASS),
        (np.int64, trc.TestResult.PASS),
        (np.uint8, trc.TestResult.PASS),
        (np.uint16, trc.TestResult.PASS),
        (np.uint32, trc.TestResult.MISMATCH),
        (np.uint64, trc.TestResult.MISMATCH),
        (np.float16, trc.TestResult.PASS),
        (np.float32, trc.TestResult.PASS),
        (np.float64, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.PASS),
    ],
)
def test_supported_types(data_type, expected):
    """Check which data types are supported."""
    # Generate deterministic data. np.ndarray() leaves memory uninitialized,
    # which for floating-point types may hold NaN or garbage and could make
    # the comparison non-deterministic.
    npy_data = np.ones(shape=(2, 3), dtype=data_type)

    # Save the same data as reference and result files, so the expected
    # outcome depends only on whether the data type is supported.
    reference_file = _create_data_file("reference.npy", npy_data)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        # Compare reference and result npy files.
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove files created, even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
| 65 | |
| 66 | |
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int32, trc.TestResult.MISMATCH),
        (np.int64, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.MISMATCH),
    ],
)
def test_shape_mismatch(data_type, expected):
    """Check that mismatch shapes do not pass."""
    # Generate and save data as reference and result files to compare; the
    # values are identical but the shapes are transposed (3x2 vs 2x3).
    npy_data = np.ones(shape=(3, 2), dtype=data_type)
    reference_file = _create_data_file("reference.npy", npy_data)
    npy_data = np.ones(shape=(2, 3), dtype=data_type)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        # Compare reference and result npy files.
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove files created, even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
| 91 | |
| 92 | |
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int32, trc.TestResult.MISMATCH),
        (np.int64, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.MISMATCH),
    ],
)
def test_results_mismatch(data_type, expected):
    """Check that different results do not pass."""
    # Generate and save data as reference and result files to compare; same
    # shape and type, but zeros in the reference and ones in the result.
    npy_data = np.zeros(shape=(2, 3), dtype=data_type)
    reference_file = _create_data_file("reference.npy", npy_data)
    npy_data = np.ones(shape=(2, 3), dtype=data_type)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        # Compare reference and result npy files.
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove files created, even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
| 117 | |
| 118 | |
@pytest.mark.parametrize(
    "data_type1,data_type2,expected",
    [  # Pairs of different (individually supported) types
        (np.int32, np.int64, trc.TestResult.MISMATCH),
        (bool, np.float32, trc.TestResult.MISMATCH),
    ],
)
def test_types_mismatch(data_type1, data_type2, expected):
    """Check that different types in results do not pass."""
    # Generate and save data as reference and result files to compare; same
    # shape and values, but different data types.
    npy_data = np.ones(shape=(3, 2), dtype=data_type1)
    reference_file = _create_data_file("reference.npy", npy_data)
    npy_data = np.ones(shape=(3, 2), dtype=data_type2)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        # Compare reference and result npy files.
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove files created, even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
| 141 | |
| 142 | |
@pytest.mark.parametrize(
    "reference_exists,result_exists,expected",
    [
        (True, False, trc.TestResult.MISSING_FILE),
        (False, True, trc.TestResult.MISSING_FILE),
    ],
)
def test_missing_files(reference_exists, result_exists, expected):
    """Check that missing files are caught."""
    # Generate and save data (initialized, unlike np.ndarray()), then delete
    # whichever file the test case wants to be missing.
    npy_data = np.zeros(shape=(2, 3), dtype=bool)
    reference_file = _create_data_file("reference.npy", npy_data)
    result_file = _create_data_file("result.npy", npy_data)
    if not reference_exists:
        _delete_data_file(reference_file)
    if not result_exists:
        _delete_data_file(result_file)

    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove only the files that still exist, even when the assertion
        # fails.
        if reference_exists:
            _delete_data_file(reference_file)
        if result_exists:
            _delete_data_file(result_file)
| 168 | |
| 169 | |
@pytest.mark.parametrize(
    "reference_numpy,result_numpy,expected",
    [
        (True, False, trc.TestResult.INCORRECT_FORMAT),
        (False, True, trc.TestResult.INCORRECT_FORMAT),
    ],
)
def test_incorrect_format_files(reference_numpy, result_numpy, expected):
    """Check that incorrect format files are caught."""
    # Generate and save data (initialized, unlike np.ndarray()); the non-numpy
    # side is replaced by an empty file.
    npy_data = np.zeros(shape=(2, 3), dtype=bool)
    reference_file = (
        _create_data_file("reference.npy", npy_data)
        if reference_numpy
        else _create_empty_file("empty.npy")
    )
    result_file = (
        _create_data_file("result.npy", npy_data)
        if result_numpy
        else _create_empty_file("empty.npy")
    )

    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Both sides resolve to the same "empty.npy" when neither is numpy
        # data, so de-duplicate to avoid deleting one file twice.
        for file in {reference_file, result_file}:
            _delete_data_file(file)