blob: d78d158f58bc7699f7b12af667306ad56ca32a82 [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""Tests for tosa_result_checker.py."""
2# Copyright (c) 2021-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4from pathlib import Path
5
Jeremy Johnson5c1364c2022-01-13 15:04:21 +00006import checker.tosa_result_checker as trc
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00007import numpy as np
8import pytest
9
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +000010
def _create_data_file(name, npy_data):
    """Write ``npy_data`` in numpy ``.npy`` format next to this test module.

    Args:
        name: File name to create in this module's directory.
        npy_data: Array to serialize with ``np.save``.

    Returns:
        Path of the written file.
    """
    data_path = Path(__file__).parent.joinpath(name)
    with data_path.open("wb") as handle:
        np.save(handle, npy_data)
    return data_path
17
18
def _create_empty_file(name):
    """Create an empty (zero-byte) file next to this test module.

    An empty file is not a valid numpy ``.npy`` file. Any existing file of the
    same name is truncated.

    Args:
        name: File name to create in this module's directory.

    Returns:
        Path of the created file.
    """
    file = Path(__file__).parent / name
    # write_bytes opens in "wb" mode (create/truncate) and always closes the
    # handle, unlike a bare open()/close() pair.
    file.write_bytes(b"")
    return file
25
26
def _delete_data_file(file: Path) -> None:
    """Remove ``file`` from disk.

    Raises:
        FileNotFoundError: if ``file`` does not exist.
    """
    target = file
    target.unlink()
30
31
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int8, trc.TestResult.MISMATCH),
        (np.int16, trc.TestResult.MISMATCH),
        (np.int32, trc.TestResult.PASS),
        (np.int64, trc.TestResult.PASS),
        (np.uint8, trc.TestResult.MISMATCH),
        (np.uint16, trc.TestResult.MISMATCH),
        (np.uint32, trc.TestResult.MISMATCH),
        (np.uint64, trc.TestResult.MISMATCH),
        (np.float16, trc.TestResult.PASS),
        (np.float32, trc.TestResult.PASS),
        (np.float64, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.PASS),
    ],
)
def test_supported_types(data_type, expected):
    """Check which data types are supported.

    Identical reference and result files are written for each dtype. The
    parametrization lists which dtypes are expected to PASS and which are
    expected to be reported as MISMATCH.
    """
    # Use deterministic contents: np.ndarray() returns uninitialized memory,
    # which for float dtypes may hold NaN and make the comparison flaky.
    npy_data = np.ones(shape=(2, 3), dtype=data_type)

    # Save data as reference and result files to compare.
    reference_file = _create_data_file("reference.npy", npy_data)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        args = [str(reference_file), str(result_file)]
        # Compare the reference and result npy files via the checker entry point.
        assert trc.main(args) == expected
    finally:
        # Remove files created, even if the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
65
66
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int32, trc.TestResult.MISMATCH),
        (np.int64, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.MISMATCH),
    ],
)
def test_shape_mismatch(data_type, expected):
    """Check that mismatch shapes do not pass.

    The reference is (3, 2) and the result is (2, 3), both filled with ones.
    """
    # Generate and save data as reference and result files to compare.
    reference_file = _create_data_file(
        "reference.npy", np.ones(shape=(3, 2), dtype=data_type)
    )
    result_file = _create_data_file(
        "result.npy", np.ones(shape=(2, 3), dtype=data_type)
    )
    try:
        args = [str(reference_file), str(result_file)]
        # Compare the reference and result npy files via the checker entry point.
        assert trc.main(args) == expected
    finally:
        # Remove files created, even if the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
91
92
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int32, trc.TestResult.MISMATCH),
        (np.int64, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.MISMATCH),
    ],
)
def test_results_mismatch(data_type, expected):
    """Check that different results do not pass.

    The reference is all zeros and the result is all ones, with the same
    shape and dtype.
    """
    # Generate and save data as reference and result files to compare.
    reference_file = _create_data_file(
        "reference.npy", np.zeros(shape=(2, 3), dtype=data_type)
    )
    result_file = _create_data_file(
        "result.npy", np.ones(shape=(2, 3), dtype=data_type)
    )
    try:
        args = [str(reference_file), str(result_file)]
        # Compare the reference and result npy files via the checker entry point.
        assert trc.main(args) == expected
    finally:
        # Remove files created, even if the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
117
118
@pytest.mark.parametrize(
    "data_type1,data_type2,expected",
    [  # Sample pairs of differing supported types
        (np.int32, np.int64, trc.TestResult.MISMATCH),
        (bool, np.float32, trc.TestResult.MISMATCH),
    ],
)
def test_types_mismatch(data_type1, data_type2, expected):
    """Check that different types in results do not pass.

    The reference and result have the same shape and contents (ones) and
    differ only in dtype.
    """
    # Generate and save data as reference and result files to compare.
    reference_file = _create_data_file(
        "reference.npy", np.ones(shape=(3, 2), dtype=data_type1)
    )
    result_file = _create_data_file(
        "result.npy", np.ones(shape=(3, 2), dtype=data_type2)
    )
    try:
        args = [str(reference_file), str(result_file)]
        # Compare the reference and result npy files via the checker entry point.
        assert trc.main(args) == expected
    finally:
        # Remove files created, even if the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
141
142
@pytest.mark.parametrize(
    "reference_exists,result_exists,expected",
    [
        (True, False, trc.TestResult.MISSING_FILE),
        (False, True, trc.TestResult.MISSING_FILE),
    ],
)
def test_missing_files(reference_exists, result_exists, expected):
    """Check that missing files are caught.

    Both files are created, then whichever one the case marks as absent is
    deleted before invoking the checker.
    """
    # Generate and save deterministic data (np.ndarray() would be uninitialized).
    npy_data = np.ones(shape=(2, 3), dtype=bool)
    reference_file = _create_data_file("reference.npy", npy_data)
    result_file = _create_data_file("result.npy", npy_data)
    if not reference_exists:
        _delete_data_file(reference_file)
    if not result_exists:
        _delete_data_file(result_file)

    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Only the files still on disk need removing; runs even on failure.
        if reference_exists:
            _delete_data_file(reference_file)
        if result_exists:
            _delete_data_file(result_file)
168
169
@pytest.mark.parametrize(
    "reference_numpy,result_numpy,expected",
    [
        (True, False, trc.TestResult.INCORRECT_FORMAT),
        (False, True, trc.TestResult.INCORRECT_FORMAT),
    ],
)
def test_incorrect_format_files(reference_numpy, result_numpy, expected):
    """Check that incorrect format files are caught.

    A side flagged as non-numpy is replaced by an empty ``empty.npy`` file,
    which is not valid numpy data.
    """
    # Generate and save deterministic data (np.ndarray() would be uninitialized).
    npy_data = np.ones(shape=(2, 3), dtype=bool)
    reference_file = (
        _create_data_file("reference.npy", npy_data)
        if reference_numpy
        else _create_empty_file("empty.npy")
    )
    result_file = (
        _create_data_file("result.npy", npy_data)
        if result_numpy
        else _create_empty_file("empty.npy")
    )

    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Deduplicate so a case where both sides are "empty.npy" does not try
        # to delete the same file twice; runs even if the assertion fails.
        for created_file in {reference_file, result_file}:
            _delete_data_file(created_file)