blob: bc8a2fc260db6dd8870d041720f018ac4ce39b25 [file] [log] [blame]
Jeremy Johnsonbe1a9402021-12-15 17:14:56 +00001"""Tests for tosa_result_checker.py."""
2# Copyright (c) 2021-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4from pathlib import Path
5
6import numpy as np
7import pytest
8
9import checker.tosa_result_checker as trc
10
11
def _create_data_file(name, npy_data):
    """Save ``npy_data`` as a ``.npy`` file called ``name`` next to this module.

    Returns the ``Path`` of the file that was written.
    """
    target = Path(__file__).parent.joinpath(name)
    with target.open("wb") as handle:
        np.save(handle, npy_data)
    return target
18
19
def _create_empty_file(name):
    """Create (or truncate) a zero-byte file called ``name`` next to this module.

    An empty file is not a valid numpy ``.npy`` file, so it can stand in for a
    badly formatted data file.

    Returns the ``Path`` of the created file.
    """
    file = Path(__file__).parent / name
    # write_bytes opens in "wb" mode, so an existing file is truncated to zero
    # bytes and the handle is always closed.
    file.write_bytes(b"")
    return file
26
27
def _delete_data_file(file: Path):
    """Delete a file created by one of the test helpers.

    ``Path.unlink`` is called without ``missing_ok``, so a ``FileNotFoundError``
    is raised if ``file`` does not exist.
    """
    file.unlink()
31
32
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int8, trc.TestResult.MISMATCH),
        (np.int16, trc.TestResult.MISMATCH),
        (np.int32, trc.TestResult.PASS),
        (np.int64, trc.TestResult.PASS),
        (np.uint8, trc.TestResult.MISMATCH),
        (np.uint16, trc.TestResult.MISMATCH),
        (np.uint32, trc.TestResult.MISMATCH),
        (np.uint64, trc.TestResult.MISMATCH),
        (np.float16, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.PASS),
        (np.float64, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.PASS),
    ],
)
def test_supported_types(data_type, expected):
    """Check which data types are supported.

    The same array is saved as both the reference and the result. Any
    MISMATCH expected here is therefore attributed to ``data_type`` itself,
    not to differing contents.
    """
    # Deterministic contents: np.ndarray(shape, dtype) returns uninitialised
    # memory, which may hold NaN/Inf for floating types or arbitrary bytes for
    # bool, making the comparison nondeterministic.
    npy_data = np.ones(shape=(2, 3), dtype=data_type)

    # Save data as reference and result files to compare.
    reference_file = _create_data_file("reference.npy", npy_data)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        args = [str(reference_file), str(result_file)]
        # trc.main compares the reference and result .npy files and returns a
        # trc.TestResult value.
        assert trc.main(args) == expected
    finally:
        # Remove the files even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
66
67
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int32, trc.TestResult.MISMATCH),
        (np.int64, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.MISMATCH),
    ],
)
def test_shape_mismatch(data_type, expected):
    """Check that mismatched shapes do not pass.

    The reference is shaped (3, 2) and the result (2, 3). Both have the same
    dtype and element count, so only the shape differs.
    """
    # Generate and save data as reference and result files to compare.
    npy_data = np.ones(shape=(3, 2), dtype=data_type)
    reference_file = _create_data_file("reference.npy", npy_data)
    npy_data = np.ones(shape=(2, 3), dtype=data_type)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        args = [str(reference_file), str(result_file)]
        # trc.main compares the reference and result .npy files and returns a
        # trc.TestResult value.
        assert trc.main(args) == expected
    finally:
        # Remove the files even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
92
93
@pytest.mark.parametrize(
    "data_type,expected",
    [
        (np.int32, trc.TestResult.MISMATCH),
        (np.int64, trc.TestResult.MISMATCH),
        (np.float32, trc.TestResult.MISMATCH),
        (bool, trc.TestResult.MISMATCH),
    ],
)
def test_results_mismatch(data_type, expected):
    """Check that different results do not pass.

    The reference is all zeros and the result all ones, with the same shape
    and dtype.
    """
    # Generate and save data as reference and result files to compare.
    npy_data = np.zeros(shape=(2, 3), dtype=data_type)
    reference_file = _create_data_file("reference.npy", npy_data)
    npy_data = np.ones(shape=(2, 3), dtype=data_type)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        args = [str(reference_file), str(result_file)]
        # trc.main compares the reference and result .npy files and returns a
        # trc.TestResult value.
        assert trc.main(args) == expected
    finally:
        # Remove the files even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
118
119
@pytest.mark.parametrize(
    "data_type1,data_type2,expected",
    [  # Mismatching pairs; each supported type appears in exactly one pair.
        (np.int32, np.int64, trc.TestResult.MISMATCH),
        (bool, np.float32, trc.TestResult.MISMATCH),
    ],
)
def test_types_mismatch(data_type1, data_type2, expected):
    """Check that different types in results do not pass.

    Reference and result have the same shape and hold ones, but are saved
    with different dtypes.
    """
    # Generate and save data as reference and result files to compare.
    npy_data = np.ones(shape=(3, 2), dtype=data_type1)
    reference_file = _create_data_file("reference.npy", npy_data)
    npy_data = np.ones(shape=(3, 2), dtype=data_type2)
    result_file = _create_data_file("result.npy", npy_data)
    try:
        args = [str(reference_file), str(result_file)]
        # trc.main compares the reference and result .npy files and returns a
        # trc.TestResult value.
        assert trc.main(args) == expected
    finally:
        # Remove the files even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
142
143
@pytest.mark.parametrize(
    "reference_exists,result_exists,expected",
    [
        (True, False, trc.TestResult.MISSING_FILE),
        (False, True, trc.TestResult.MISSING_FILE),
    ],
)
def test_missing_files(reference_exists, result_exists, expected):
    """Check that missing files are caught.

    Both files are written, then the one flagged as non-existent is deleted
    before the checker runs.
    """
    # Generate and save data. Contents are irrelevant here, but use
    # deterministic values rather than np.ndarray's uninitialised memory.
    npy_data = np.ones(shape=(2, 3), dtype=bool)
    reference_file = _create_data_file("reference.npy", npy_data)
    result_file = _create_data_file("result.npy", npy_data)
    if not reference_exists:
        _delete_data_file(reference_file)
    if not result_exists:
        _delete_data_file(result_file)

    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove only the files still on disk, even when the assertion fails.
        if reference_exists:
            _delete_data_file(reference_file)
        if result_exists:
            _delete_data_file(result_file)
169
170
@pytest.mark.parametrize(
    "reference_numpy,result_numpy,expected",
    [
        (True, False, trc.TestResult.INCORRECT_FORMAT),
        (False, True, trc.TestResult.INCORRECT_FORMAT),
    ],
)
def test_incorrect_format_files(reference_numpy, result_numpy, expected):
    """Check that incorrect format files are caught.

    The side flagged as non-numpy is replaced by an empty (zero-byte) file.
    """
    # Generate and save data. Use deterministic values rather than
    # np.ndarray's uninitialised memory.
    npy_data = np.ones(shape=(2, 3), dtype=bool)
    # The empty files get distinct names so reference and result never alias
    # the same path. With a shared name, a case where both sides were
    # non-numpy would delete one file twice during cleanup.
    reference_file = (
        _create_data_file("reference.npy", npy_data)
        if reference_numpy
        else _create_empty_file("reference_empty.npy")
    )
    result_file = (
        _create_data_file("result.npy", npy_data)
        if result_numpy
        else _create_empty_file("result_empty.npy")
    )

    try:
        args = [str(reference_file), str(result_file)]
        assert trc.main(args) == expected
    finally:
        # Remove the files even when the assertion fails.
        _delete_data_file(reference_file)
        _delete_data_file(result_file)
197 _delete_data_file(result_file)