blob: 79e67203c322cf0beec7e4ba93c971bd5671cea8 [file] [log] [blame]
Jeremy Johnson00423432022-09-12 17:27:37 +01001"""Tests for tosa_reference_model."""
2# Copyright (c) 2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4import json
Jeremy Johnsona0848c62022-09-15 15:01:30 +01005import re
Jeremy Johnson00423432022-09-12 17:27:37 +01006from pathlib import Path
7from shutil import rmtree
8
9import numpy as np
10import pytest
11from checker.tosa_result_checker import test_check as tosa_check
12from checker.tosa_result_checker import TestResult as TosaResult
13from generator.tosa_verif_build_tests import main as tosa_builder
14from runner.run_command import run_sh_command
15from runner.run_command import RunShCommandError
16
# Note: Must rename imports so that pytest doesn't assume it's a test function/class
18
# Flip to False to keep the generated test directories once the tests finish
CLEAN_UP_TESTS = True

# Reference model binary: <grandparent of this file's directory>/build/reference_model
REF_MODEL_PATH = Path(__file__).resolve().parents[2].joinpath("build", "reference_model")
REF_MODEL_EXE = "tosa_reference_model"
REF_MODEL = REF_MODEL_PATH.joinpath(REF_MODEL_EXE)

# Default tensor shape used when generating tests
SHAPE_LIST = ["10", "5"]
SHAPE_DIMS = len(SHAPE_LIST)
SHAPE_ARG = ",".join(SHAPE_LIST)  # value for the generator's --target-shape, e.g. "10,5"
SHAPE_OUT = "x".join(SHAPE_LIST)  # e.g. "10x5"

# Names of files/directories produced while running the tests
OUTPUT_DIR_PREFIX = "_pytest_vtest"
OUTPUT_OFM_FILE = "result_refmodel_pytest.npy"  # reference model output (--ofm_file)
OUTPUT_RESULT_FILE = "result_numpy_pytest.npy"  # Numpy-computed expected result
OUTPUT_CONST_GLOB = "const-*.npy"  # constant tensors dumped by the generator

TEST_DESC_FILENAME = "desc.json"
TOSA_LEVEL = "EIGHTK"

# Reference model data type -> abbreviation that appears in generated test names
REF_MODEL_TYPE_TO_OUT = dict(
    bool="b",
    int8="i8",
    uint8="u8",
    int16="i16",
    int32="i32",
    fp32="f32",
    fp16="f16",
    bf16="bf16",
)
53
54
@pytest.mark.postcommit
def test_refmodel_built():
    """Sanity check, defined first in the module, that the reference model binary exists."""
    ref_model_exists = REF_MODEL.is_file()
    assert ref_model_exists
59
60
class BuildTosaTest:
    """Wrapper for managing the lifecycle of generated TOSA unit tests.

    Generates the tests for one operator / data type combination with
    ``tosa_verif_build_tests`` and deletes the generated output afterwards.
    """

    def __init__(self, op_name, ref_model_type, num_expected_tests):
        """Record the test parameters; nothing is generated yet.

        Args:
            op_name: operator name passed to the generator's ``--filter``.
            ref_model_type: reference model data type, e.g. "int32" or "fp16".
            num_expected_tests: number of test directories the generator must
                produce for this operator / data type combination.
        """
        self.op_name = op_name
        self.ref_model_type = ref_model_type
        self.num_expected_tests = num_expected_tests
        # Root output directory; None until create_test() runs and again after
        # remove_test() has deleted it.
        self.output_dir = None
        # Sorted list of generated test directories, filled by create_test().
        self.test_dirs = None

    def create_test(self):
        """Generate the TOSA unit test(s) and return their directories.

        Repeated calls return the previously generated directories without
        re-running the generator.

        Returns:
            Sorted list of ``Path`` objects, one per generated test directory.
        """
        if self.output_dir is not None:
            # Already created - return the cached directories
            return self.test_dirs

        self.output_dir = (
            Path(__file__).parent
            / f"{OUTPUT_DIR_PREFIX}_{self.op_name}_{self.ref_model_type}"
        )

        # Generate tests without any zero-point
        build_args = [
            "--filter",
            self.op_name,
            "--target-shape",
            SHAPE_ARG,
            "--target-dtype",
            self.ref_model_type,
            "--zero-point",
            "0",
            "--num-const-inputs-concat",
            "1",
            "--dump-const-tensors",
            "-o",
            str(self.output_dir),
        ]
        print(f"### Building tests: tosa_verif_build_tests {' '.join(build_args)}")
        tosa_builder(build_args)

        # Find the created test(s)
        op_dir = self.output_dir / self.op_name
        # Can't assume exact name due to broadcasting and other changes to shape
        test_glob = f"{self.op_name}_*_{REF_MODEL_TYPE_TO_OUT[self.ref_model_type]}*"
        test_dirs = sorted(op_dir.glob(test_glob))
        assert len(test_dirs) == self.num_expected_tests, (
            f"Expected {self.num_expected_tests} test(s) matching {test_glob} "
            f"in {op_dir}, found {len(test_dirs)}: {test_dirs}"
        )
        for candidate in test_dirs:
            assert candidate.is_dir(), f"{candidate} is not a directory"
        self.test_dirs = test_dirs

        return self.test_dirs

    def remove_test(self):
        """Delete the generated output directory unless CLEAN_UP_TESTS is False.

        Does nothing when no tests were generated or the directory is missing.
        """
        if self.output_dir is None or not self.output_dir.is_dir():
            return

        test_tree = self.output_dir.resolve()
        if not CLEAN_UP_TESTS:
            print(f"Skipped clean up of {test_tree}")
            return

        print(f"Deleting {test_tree}")
        rmtree(str(test_tree))
        self.output_dir = None
        # The cached directories no longer exist on disk
        self.test_dirs = None
123
124
# Tests - (op_name, ref_model_type, num_expected_tests) tuples, expanded from
# per-operator groups of data types, keeping the order listed here
TEST_PARAMS = [
    (op_name, ref_model_type, num_expected_tests)
    for op_name, ref_model_types, num_expected_tests in (
        ("add", ("int32", "fp32"), 1),
        ("abs", ("int32", "fp32", "fp16", "bf16"), 1),
        ("negate", ("int8", "int16", "int32", "fp32", "fp16", "bf16"), 1),
        # One test per axis (shape dimensions)
        (
            "concat",
            ("bool", "int8", "int16", "int32", "fp32", "fp16", "bf16"),
            SHAPE_DIMS,
        ),
    )
    for ref_model_type in ref_model_types
]
148
149
def id_2_name(id):
    """Build a readable pytest id "<op_name>-<ref_model_type>" from a TEST_PARAMS entry.

    Without this pytest would label the parametrized cases tosaTestN.
    """
    op_name, ref_model_type, _num_expected_tests = id
    return "{}-{}".format(op_name, ref_model_type)
154
155
@pytest.fixture(params=TEST_PARAMS, ids=id_2_name)
def tosaTest(request):
    """Provide a BuildTosaTest for one TEST_PARAMS entry and clean it up afterwards."""
    op_name, ref_model_type, num_expected_tests = request.param
    builder = BuildTosaTest(
        op_name=op_name,
        ref_model_type=ref_model_type,
        num_expected_tests=num_expected_tests,
    )
    yield builder
    # Teardown once the test using this fixture is done
    builder.remove_test()
163
164
def _compute_numpy_result(op_name, tensors, consts, test_dir):
    """Return the Numpy reference result for one generated TOSA test.

    Args:
        op_name: operator under test ("abs", "add", "concat" or "negate").
        tensors: input arrays loaded from the test description's "ifm_file" list.
        consts: constant arrays loaded from the test's OUTPUT_CONST_GLOB files.
        test_dir: test directory; for "concat" its name contains "axis<N>".

    Fails the current test for an unknown operator.
    """
    if op_name == "abs":
        assert len(tensors) == 1
        return np.abs(tensors[0])
    if op_name == "add":
        assert len(tensors) == 2
        return np.add(tensors[0], tensors[1])
    if op_name == "concat":
        assert len(consts) == 1
        # Get axis from test directory name
        match = re.search(r"axis([0-9]+)", test_dir.name)
        assert match is not None, f"No axis in test name {test_dir.name}"
        axis = int(match.group(1))
        return np.concatenate((*tensors, consts[0]), axis=axis)
    if op_name == "negate":
        assert len(tensors) == 1
        return np.negative(tensors[0])
    pytest.fail(f"Unknown operation {op_name}")


@pytest.mark.postcommit
def test_refmodel_simple_op(tosaTest):
    """Operator testing versus Numpy.

    Generates the TOSA test(s) for the fixture's operator/data type, runs each
    one on the reference model, computes the expected result with Numpy and
    compares the two with the TOSA result checker.
    """
    op_name = tosaTest.op_name

    # Generate TOSA test(s) (mostly should be single test)
    test_dirs = tosaTest.create_test()

    # Miscellaneous checks to run in tosa_check. They depend only on the data
    # type, so build the list once instead of appending per test directory
    # (which duplicated entries for multi-test cases such as concat).
    misc_checks = []
    if tosaTest.ref_model_type == "bf16":
        # Ensure valid bf16
        misc_checks.append("bf16")

    for test_dir in test_dirs:
        # Run ref model
        desc_file = test_dir / TEST_DESC_FILENAME
        assert desc_file.is_file()
        refmodel_cmd = [
            str(REF_MODEL),
            "--test_desc",
            str(desc_file),
            "--ofm_file",
            OUTPUT_OFM_FILE,
            "--tosa_level",
            TOSA_LEVEL,
        ]
        try:
            run_sh_command(refmodel_cmd, verbose=True, capture_output=True)
        except RunShCommandError as err:
            pytest.fail(f"Unexpected exception {err}")

        # Find output
        ofm_file = test_dir / OUTPUT_OFM_FILE
        assert ofm_file.is_file()

        # Load inputs for Numpy
        with desc_file.open("r") as fp:
            test_desc = json.load(fp)
        assert "ifm_file" in test_desc
        tensors = []
        for input_name in test_desc["ifm_file"]:
            input_file = test_dir / input_name
            assert input_file.is_file()
            tensors.append(np.load(str(input_file)))

        # Load constants for Numpy
        consts = []
        for const_file in sorted(test_dir.glob(OUTPUT_CONST_GLOB)):
            assert const_file.is_file()
            consts.append(np.load(str(const_file)))

        # Perform Numpy operation
        result = _compute_numpy_result(op_name, tensors, consts, test_dir)

        # Save Numpy result
        result_file = test_dir / OUTPUT_RESULT_FILE
        np.save(str(result_file), result)
        assert result_file.is_file()

        # Check Numpy result versus refmodel; report the checker's message on failure
        check_result, _tolerance, msg = tosa_check(
            result_file,
            ofm_file,
            test_name=test_dir.name,
            misc_checks=misc_checks,
        )
        assert check_result == TosaResult.PASS, msg