blob: 1f9cd3e6172420287c649e6ae02e239460fc7869 [file] [log] [blame]
Jeremy Johnson00423432022-09-12 17:27:37 +01001"""Tests for tosa_reference_model."""
2# Copyright (c) 2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4import json
Jeremy Johnsona0848c62022-09-15 15:01:30 +01005import re
Jeremy Johnson00423432022-09-12 17:27:37 +01006from pathlib import Path
7from shutil import rmtree
8
9import numpy as np
10import pytest
11from checker.tosa_result_checker import test_check as tosa_check
12from checker.tosa_result_checker import TestResult as TosaResult
13from generator.tosa_verif_build_tests import main as tosa_builder
14from runner.run_command import run_sh_command
15from runner.run_command import RunShCommandError
16
# Note: Must rename imports so that pytest doesn't assume it's a test function/class
18
# Set this to False if you want to preserve the test directories after running
CLEAN_UP_TESTS = True

# Location of the reference model binary: <two levels above this file's
# directory>/build/reference_model/tosa_reference_model
REF_MODEL_PATH = Path(__file__).resolve().parents[2].joinpath("build", "reference_model")
REF_MODEL_EXE = "tosa_reference_model"
REF_MODEL = REF_MODEL_PATH.joinpath(REF_MODEL_EXE)

# Default tensor shape: the dimensions as strings, their count (rank), the
# comma-joined form passed to the builder's --target-shape, and an "x"-joined form
SHAPE_LIST = ["10", "5"]
SHAPE_DIMS = len(SHAPE_LIST)
SHAPE_ARG = ",".join(SHAPE_LIST)
SHAPE_OUT = "x".join(SHAPE_LIST)

# Names of the files written/read inside the generated test directories
OUTPUT_DIR_PREFIX = "_pytest_vtest"
OUTPUT_OFM_FILE = "result_refmodel_pytest.npy"
OUTPUT_RESULT_FILE = "result_numpy_pytest.npy"
OUTPUT_CONST_GLOB = "const-*.npy"

TEST_DESC_FILENAME = "desc.json"

# Maps a reference model data type to the abbreviation that appears in the
# generated test names (insertion order matters only for readability)
REF_MODEL_TYPE_TO_OUT = dict(
    bool="b",
    int8="i8",
    uint8="u8",
    int16="i16",
    int32="i32",
    fp32="f32",
    fp16="f16",
    bf16="bf16",
)
52
53
@pytest.mark.postcommit
def test_refmodel_built():
    """Sanity check (run first): the reference model executable must exist."""
    ref_model_present = REF_MODEL.is_file()
    assert ref_model_present
58
59
class BuildTosaTest:
    """Wrapper for managing lifecycle of TOSA unit tests.

    Builds the tests for one operator / data type combination with
    ``tosa_verif_build_tests`` and deletes the generated output afterwards.
    """

    def __init__(self, op_name, ref_model_type, num_expected_tests):
        """Record the test parameters; nothing is generated until create_test().

        Args:
            op_name: operator name, passed to the builder as ``--filter``.
            ref_model_type: reference model data type, passed as
                ``--target-dtype`` (must be a key of REF_MODEL_TYPE_TO_OUT).
            num_expected_tests: number of test directories the build must produce.
        """
        self.op_name = op_name
        self.ref_model_type = ref_model_type
        self.num_expected_tests = num_expected_tests
        # Populated by create_test(), reset by remove_test() on clean up
        self.output_dir = None
        self.test_dirs = None

    def create_test(self):
        """Generate the TOSA unit test(s) and return their sorted directories.

        Repeated calls return the directories found by the first call instead
        of rebuilding.

        Returns:
            Sorted list of ``Path`` objects, one per generated test directory.
        """
        if self.test_dirs is not None:
            # Already created
            return self.test_dirs

        # Set before building so remove_test() can clean up a partial build
        self.output_dir = (
            Path(__file__).parent
            / f"{OUTPUT_DIR_PREFIX}_{self.op_name}_{self.ref_model_type}"
        )

        # Generate tests without any zero-point
        build_args = [
            "--filter",
            self.op_name,
            "--target-shape",
            SHAPE_ARG,
            "--target-dtype",
            self.ref_model_type,
            "--zero-point",
            "0",
            "--num-const-inputs-concat",
            "1",
            "--dump-const-tensors",
            "-o",
            str(self.output_dir),
        ]
        print(f"### Building tests: tosa_verif_build_tests {' '.join(build_args)}")
        tosa_builder(build_args)

        # Find the created test(s) - can't assume exact name due to
        # broadcasting and other changes to shape
        op_dir = self.output_dir / self.op_name
        test_glob = f"{self.op_name}_*_{REF_MODEL_TYPE_TO_OUT[self.ref_model_type]}*"
        test_dirs = sorted(op_dir.glob(test_glob))
        assert len(test_dirs) == self.num_expected_tests, (
            f"Expected {self.num_expected_tests} test(s) matching {test_glob} "
            f"in {op_dir}, found {len(test_dirs)}"
        )
        for test_dir in test_dirs:
            assert test_dir.is_dir()
        self.test_dirs = test_dirs

        return self.test_dirs

    def remove_test(self):
        """Delete the generated output directory unless CLEAN_UP_TESTS is False."""
        if self.output_dir is not None and self.output_dir.is_dir():
            # Delete directory
            test_tree = self.output_dir.resolve()
            if CLEAN_UP_TESTS:
                print(f"Deleting {test_tree}")
                rmtree(str(test_tree))
                self.output_dir = None
                # The cached directories no longer exist
                self.test_dirs = None
            else:
                print(f"Skipped clean up of {test_tree}")
122
123
# Tests - op_name, ref_model_type, num_expected_tests
# The element-wise ops expect a single test per data type, while concat
# expects one test per axis (shape dimensions).
TEST_PARAMS = [
    *(("add", dtype, 1) for dtype in ("int32", "fp32")),
    *(("abs", dtype, 1) for dtype in ("int32", "fp32", "fp16", "bf16")),
    *(
        ("negate", dtype, 1)
        for dtype in ("int8", "int16", "int32", "fp32", "fp16", "bf16")
    ),
    *(
        ("concat", dtype, SHAPE_DIMS)
        for dtype in ("bool", "int8", "int16", "int32", "fp32", "fp16", "bf16")
    ),
]
147
148
def id_2_name(id):
    """Turn a TEST_PARAMS entry into a readable pytest id ("<op>-<type>").

    Without this, pytest would name the parametrized cases tosaTestN.
    """
    op, dtype, _unused_count = id
    return "{}-{}".format(op, dtype)
153
154
@pytest.fixture(params=TEST_PARAMS, ids=id_2_name)
def tosaTest(request):
    """Provide a BuildTosaTest for one TEST_PARAMS entry, then tidy up.

    The generated test output is removed after the requesting test finishes.
    """
    op, dtype, expected_count = request.param
    builder = BuildTosaTest(op, dtype, expected_count)
    yield builder
    # Teardown: runs once the test using this fixture has completed
    builder.remove_test()
162
163
def _run_ref_model(desc_file):
    """Run the reference model binary on the test described by ``desc_file``."""
    refmodel_cmd = [
        str(REF_MODEL),
        "--test_desc",
        str(desc_file),
        "--ofm_file",
        OUTPUT_OFM_FILE,
    ]
    try:
        run_sh_command(refmodel_cmd, verbose=True, capture_output=True)
    except RunShCommandError as err:
        pytest.fail(f"Unexpected exception {err}")


def _load_numpy_operands(test_dir, test_desc):
    """Load the input tensors listed in ``test_desc`` and any dumped constants.

    Returns:
        (tensors, consts): lists of arrays, inputs in ``ifm_file`` order and
        constants in sorted file-name order.
    """
    assert "ifm_file" in test_desc
    tensors = []
    for input_name in test_desc["ifm_file"]:
        input_file = test_dir / input_name
        assert input_file.is_file()
        tensors.append(np.load(str(input_file)))

    consts = []
    for const_file in sorted(test_dir.glob(OUTPUT_CONST_GLOB)):
        assert const_file.is_file()
        consts.append(np.load(str(const_file)))
    return tensors, consts


def _numpy_reference(op_name, test_dir, tensors, consts):
    """Compute the expected result of ``op_name`` on the loaded operands with Numpy."""
    if op_name == "abs":
        assert len(tensors) == 1
        return np.abs(tensors[0])
    if op_name == "add":
        assert len(tensors) == 2
        return np.add(tensors[0], tensors[1])
    if op_name == "concat":
        assert len(consts) == 1
        # Get axis from test directory name
        match = re.search(r"axis([0-9]+)", test_dir.name)
        assert match is not None, f"No axis found in test name {test_dir.name}"
        axis = int(match.group(1))
        return np.concatenate((*tensors, consts[0]), axis=axis)
    if op_name == "negate":
        assert len(tensors) == 1
        return np.negative(tensors[0])
    pytest.fail(f"Unknown operation {op_name}")


@pytest.mark.postcommit
def test_refmodel_simple_op(tosaTest):
    """Operator testing versus Numpy.

    For every generated test directory: run the reference model, recompute
    the expected output with Numpy, and compare the two with tosa_check.
    """
    op_name = tosaTest.op_name

    # Generate TOSA test(s) (mostly should be single test)
    test_dirs = tosaTest.create_test()

    # Miscellaneous checks to run in tosa_check. They depend only on the data
    # type, so decide them once instead of re-appending for every test dir.
    misc_checks = ["bf16"] if tosaTest.ref_model_type == "bf16" else []

    for test_dir in test_dirs:
        # Run ref model
        desc_file = test_dir / TEST_DESC_FILENAME
        assert desc_file.is_file()
        _run_ref_model(desc_file)

        # Find output
        ofm_file = test_dir / OUTPUT_OFM_FILE
        assert ofm_file.is_file()

        # Load inputs/constants and perform the Numpy operation
        with desc_file.open("r") as fp:
            test_desc = json.load(fp)
        tensors, consts = _load_numpy_operands(test_dir, test_desc)
        result = _numpy_reference(op_name, test_dir, tensors, consts)

        # Save Numpy result
        result_file = test_dir / OUTPUT_RESULT_FILE
        np.save(str(result_file), result)
        assert result_file.is_file()

        # Check Numpy result versus refmodel
        check_result, _tolerance, msg = tosa_check(
            result_file,
            ofm_file,
            test_name=test_dir.name,
            misc_checks=misc_checks,
        )
        assert check_result == TosaResult.PASS, msg