blob: bb52a8665e0ee1525fa5755f0ad1299457bf1a8d [file] [log] [blame]
Jeremy Johnson00423432022-09-12 17:27:37 +01001"""Tests for tosa_reference_model."""
evacha014a205112024-03-08 16:39:24 +00002# Copyright (c) 2022-2024, ARM Limited.
Jeremy Johnson00423432022-09-12 17:27:37 +01003# SPDX-License-Identifier: Apache-2.0
4import json
Jeremy Johnsona0848c62022-09-15 15:01:30 +01005import re
Jeremy Johnson00423432022-09-12 17:27:37 +01006from pathlib import Path
7from shutil import rmtree
8
Jeremy Johnson65ba8092023-10-09 16:31:13 +01009import conformance.model_files as cmf
Jeremy Johnson00423432022-09-12 17:27:37 +010010import numpy as np
11import pytest
12from checker.tosa_result_checker import test_check as tosa_check
13from checker.tosa_result_checker import TestResult as TosaResult
14from generator.tosa_verif_build_tests import main as tosa_builder
15from runner.run_command import run_sh_command
16from runner.run_command import RunShCommandError
17
# Note: Must rename imports (like test_check) so that pytest doesn't assume it's a test function/class
19
# Location of reference model binaries.
# The search root is two directory levels above the one holding this file
# (presumably the repository checkout root - confirm if the layout changes).
REF_MODEL_DIR = Path(__file__).resolve().parents[2]

# Reference model executable, looked up under REF_MODEL_DIR.
# NOTE(review): the meaning of the trailing False is defined by
# conformance.model_files.find_tosa_file - confirm there.
REF_MODEL_EXE_PATH = cmf.find_tosa_file(
    cmf.TosaFileType.REF_MODEL,
    REF_MODEL_DIR,
    False,
)

# Data generation library, located relative to the executable found above.
GENERATE_LIB_PATH = cmf.find_tosa_file(
    cmf.TosaFileType.GENERATE_LIBRARY,
    REF_MODEL_EXE_PATH,
)
Jeremy Johnson00423432022-09-12 17:27:37 +010028
# Set this to False to preserve the generated test directories after running
# (handy when investigating a failure).
CLEAN_UP_TESTS = True

# Default tensor shape information.
# SHAPE_LIST holds the dimensions as strings. SHAPE_ARG is the comma separated
# form handed to the test generator. SHAPE_OUT is the "x" separated form.
SHAPE_LIST = [str(dim) for dim in (10, 5)]
SHAPE_DIMS = len(SHAPE_LIST)
SHAPE_ARG = ",".join(SHAPE_LIST)
SHAPE_OUT = "x".join(SHAPE_LIST)

# Output file information
OUTPUT_DIR_PREFIX = "_pytest_vtest"  # prefix of each generated output directory
OUTPUT_OFM_FILE = "result_refmodel_pytest.npy"  # result written by the ref model
OUTPUT_RESULT_FILE = "result_numpy_pytest.npy"  # result computed with NumPy
OUTPUT_CONST_GLOB = "const-*.npy"  # dumped constant tensor files

TEST_DESC_FILENAME = "desc.json"
TOSA_LEVEL = "EIGHTK"  # value given to the reference model's --tosa_level

# Conversion from refmodel type into the type abbreviation used in the
# generated test directory names
REF_MODEL_TYPE_TO_OUT = dict(
    bool="b",
    int8="i8",
    uint8="u8",
    int16="i16",
    int32="i32",
    fp32="f32",
    fp16="f16",
    bf16="bf16",
)
58
Jeremy Johnson65ba8092023-10-09 16:31:13 +010059# NOTE: These tests are marked as POST COMMIT
60# To run them, please build the reference_model in a local "build" directory
61# (as per the README) and run them using: pytest -m "postcommit"
Jeremy Johnson48df8c72023-09-12 14:52:34 +010062
Jeremy Johnson00423432022-09-12 17:27:37 +010063
@pytest.mark.postcommit
def test_refmodel_built():
    """First test to check the reference model has been built.

    The operator tests below run REF_MODEL_EXE_PATH, so this fails early with
    the path that was checked when the executable is missing.
    """
    assert REF_MODEL_EXE_PATH.is_file(), (
        f"Reference model executable not found at {REF_MODEL_EXE_PATH}; "
        "build the reference model first (see the README)"
    )
Jeremy Johnson00423432022-09-12 17:27:37 +010068
69
class BuildTosaTest:
    """Wrapper for managing lifecycle of TOSA unit tests.

    Generates the tests for one operator / data type combination with the
    TOSA test builder, and removes the generated output directory afterwards.
    """

    def __init__(self, op_name, ref_model_type, num_expected_tests):
        """Record the test parameters; nothing is generated until create_test().

        Args:
            op_name: operator name, passed to the builder as ``--filter``.
            ref_model_type: reference model data type (e.g. ``"fp32"``); must
                be a key of REF_MODEL_TYPE_TO_OUT.
            num_expected_tests: number of test directories the builder is
                expected to produce for this combination.
        """
        self.op_name = op_name
        self.ref_model_type = ref_model_type
        self.num_expected_tests = num_expected_tests
        # Root directory the builder writes into (None until created)
        self.output_dir = None
        # Sorted generated test directories (None until created)
        self.test_dirs = None

    def create_test(self):
        """Helper to generate the TOSA unit test(s).

        Returns:
            Sorted list of Path objects, one per generated test directory.
            Once created, later calls return the same list without rebuilding.

        Raises:
            AssertionError: if the number of generated tests differs from
                num_expected_tests, or a match is not a directory.
        """
        if self.test_dirs is not None:
            # Already created - return the cached result
            return self.test_dirs

        self.output_dir = (
            Path(__file__).parent
            / f"{OUTPUT_DIR_PREFIX}_{self.op_name}_{self.ref_model_type}"
        )

        # Generate tests without any zero-point.
        # The NumPy check for concat expects exactly one constant input, read
        # back from the dumped constant tensor files.
        build_args = [
            "--generate-lib-path",
            str(GENERATE_LIB_PATH),
            "--filter",
            self.op_name,
            "--target-shape",
            SHAPE_ARG,
            "--target-dtype",
            self.ref_model_type,
            "--zero-point",
            "0",
            "--num-const-inputs-concat",
            "1",
            "--dump-const-tensors",
            "-o",
            str(self.output_dir),
        ]
        print(f"### Building tests: tosa_verif_build_tests {' '.join(build_args)}")
        tosa_builder(build_args)

        # Find the created tests in the sub-directory named after the operator.
        # Can't assume exact name due to broadcasting and other changes to
        # shape, so match on the op name and the type abbreviation.
        op_dir = self.output_dir / self.op_name
        type_abbrev = REF_MODEL_TYPE_TO_OUT[self.ref_model_type]
        test_dirs = sorted(op_dir.glob(f"{self.op_name}_*_{type_abbrev}*"))
        assert len(test_dirs) == self.num_expected_tests, (
            f"Expected {self.num_expected_tests} test(s) in {op_dir}, "
            f"found {len(test_dirs)}: {[d.name for d in test_dirs]}"
        )
        for test_dir in test_dirs:
            assert test_dir.is_dir(), f"{test_dir} is not a directory"
        self.test_dirs = test_dirs

        return self.test_dirs

    def remove_test(self):
        """Delete the generated output directory when CLEAN_UP_TESTS is set.

        Does nothing if no output directory has been created or it no longer
        exists. After deletion the cached paths are forgotten, so a later
        create_test() call rebuilds the tests.
        """
        if self.output_dir is None or not self.output_dir.is_dir():
            return

        test_tree = self.output_dir.resolve()
        if CLEAN_UP_TESTS:
            print(f"Deleting {test_tree}")
            rmtree(str(test_tree))
            self.output_dir = None
            # The cached directories no longer exist
            self.test_dirs = None
        else:
            print(f"Skipped clean up of {test_tree}")
134
135
# Test parameters - (op_name, ref_model_type, num_expected_tests).
# FP special datagen adds a second expected test to the FP16 and FP32 tests
# of the operators it is enabled for (here that applies to add/fp32).
TEST_PARAMS = [
    ("add", "int32", 1),
    ("add", "fp32", 2),
    *(("abs", dtype, 1) for dtype in ("int32", "fp32", "fp16", "bf16")),
    *(
        ("negate", dtype, 1)
        for dtype in ("int8", "int16", "int32", "fp32", "fp16", "bf16")
    ),
    # concat: one test per axis (shape dimensions)
    *(
        ("concat", dtype, SHAPE_DIMS)
        for dtype in ("bool", "int8", "int16", "int32", "fp32", "fp16", "bf16")
    ),
]
160
161
def id_2_name(id):
    """Build the pytest id ``"<op_name>-<ref_model_type>"`` for a TEST_PARAMS entry.

    Without this, pytest would name the parametrized cases tosaTest0, tosaTest1, ...
    """
    op_name, ref_model_type, _num_expected_tests = id
    return "{}-{}".format(op_name, ref_model_type)
166
167
@pytest.fixture(params=TEST_PARAMS, ids=id_2_name)
def tosaTest(request):
    """Provide a BuildTosaTest for one TEST_PARAMS entry and clean up after it."""
    op_name, ref_model_type, num_expected_tests = request.param
    builder = BuildTosaTest(
        op_name=op_name,
        ref_model_type=ref_model_type,
        num_expected_tests=num_expected_tests,
    )
    yield builder
    # Teardown - executed by pytest once the test using this fixture finishes
    builder.remove_test()
175
176
def _numpy_reference(op_name, tensors, consts, test_dir):
    """Compute the expected result of a simple operator with NumPy.

    Args:
        op_name: operator name - one of "abs", "add", "concat", "negate".
        tensors: NumPy input arrays, in the order listed by the test's
            ``ifm_file`` entry.
        consts: NumPy arrays loaded from the dumped constant tensor files.
        test_dir: Path of the test directory. Its name supplies the concat
            axis (``axis<N>``), and a ``*_fs`` suffix marks FP special data.

    Returns:
        The NumPy result array.
    """
    if op_name == "abs":
        assert len(tensors) == 1
        return np.abs(tensors[0])

    if op_name == "add":
        assert len(tensors) == 2
        # FP special datagen can give invalid results, so only for that data
        # silence NumPy's invalid floating point operation handling
        if test_dir.match("*_fs"):
            with np.errstate(invalid="ignore"):
                return np.add(tensors[0], tensors[1])
        return np.add(tensors[0], tensors[1])

    if op_name == "concat":
        assert len(consts) == 1
        # Get axis from test directory name
        match = re.search(r"axis([0-9]+)", test_dir.name)
        assert match is not None, f"No axis found in test name {test_dir.name}"
        return np.concatenate((*tensors, consts[0]), axis=int(match.group(1)))

    if op_name == "negate":
        assert len(tensors) == 1
        return np.negative(tensors[0])

    pytest.fail(f"Unknown operation {op_name}")


@pytest.mark.postcommit
def test_refmodel_simple_op(tosaTest):
    """Operator testing versus Numpy.

    For every generated test: run the reference model on it, compute the same
    operation with NumPy from the test's input and constant files, and compare
    the two results with the TOSA result checker.
    """
    op_name = tosaTest.op_name

    # Generate TOSA test(s) (mostly should be single test)
    test_dirs = tosaTest.create_test()

    # Miscellaneous checks to run in tosa_check ("bf16" ensures valid bf16).
    # Built once, outside the loop, so entries are not repeated per test.
    misc_checks = ["bf16"] if tosaTest.ref_model_type == "bf16" else []

    for test_dir in test_dirs:
        # Run ref model
        desc_file = test_dir / TEST_DESC_FILENAME
        assert desc_file.is_file()
        refmodel_cmd = [
            str(REF_MODEL_EXE_PATH),
            "--test_desc",
            str(desc_file),
            "--ofm_file",
            OUTPUT_OFM_FILE,
            "--tosa_level",
            TOSA_LEVEL,
        ]
        try:
            run_sh_command(refmodel_cmd, verbose=True, capture_output=True)
        except RunShCommandError as err:
            pytest.fail(f"Unexpected exception {err}")

        # The reference model output is expected inside the test directory
        ofm_file = test_dir / OUTPUT_OFM_FILE
        assert ofm_file.is_file()

        # Load inputs for Numpy
        with desc_file.open("r") as fp:
            test_desc = json.load(fp)
        assert "ifm_file" in test_desc
        input_files = [test_dir / name for name in test_desc["ifm_file"]]
        for input_file in input_files:
            assert input_file.is_file()
        tensors = [np.load(str(input_file)) for input_file in input_files]

        # Load constants for Numpy
        const_files = sorted(test_dir.glob(OUTPUT_CONST_GLOB))
        for const_file in const_files:
            assert const_file.is_file()
        consts = [np.load(str(const_file)) for const_file in const_files]

        # Perform Numpy operation and save the result
        result = _numpy_reference(op_name, tensors, consts, test_dir)
        result_file = test_dir / OUTPUT_RESULT_FILE
        np.save(str(result_file), result)
        assert result_file.is_file()

        # Check Numpy result versus refmodel
        check_result, _tolerance, msg = tosa_check(
            result_file,
            ofm_file,
            test_name=test_dir.name,
            misc_checks=misc_checks,
        )
        assert check_result == TosaResult.PASS, msg