#!/usr/bin/env python3
# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
"""Build conformance tests.

Steps:
- Specific input shapes (or tests) are specified and produced by using the
  settings in the .json files.
- Tests are selected to produce good coverage.
- Tests are run on the reference model to produce the correct output files.
- Tests are converted into JSON format and saved to the desired output directory.
"""
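# Example invocation (a sketch: the entry-point name and directory/operator
# values are illustrative; the flags are the ones defined in parse_args below):
#   tosa_verif_conformance_generator --profile tosa-bi \
#       --ref-model-directory <path-to-built-reference_model> \
#       --operators add sub --unit-tests operator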
import argparse
import copy
import json
import logging
import multiprocessing as mp
import os
import shlex
import shutil
import subprocess
from functools import partial
from itertools import tee
from pathlib import Path

from conformance.test_select import Operator
from convert2conformance.convert2conformance import main as c2c_main
from distutils.dir_util import copy_tree

logging.basicConfig()
logger = logging.getLogger("tosa_verif_conformance_generator")

# Configuration for each TOSA profile
PROFILE_OPS_INFO = {
    "tosa-bi": {
        "operator_test_params": "tosa_base_profile_ops_info.json",
        "framework_tests": "tosa_base_profile_framework_ops_info.json",
    },
    "tosa-mi": {
        # Note: This is just the extra tests not in the base profile!
        "operator_test_params": "tosa_main_profile_ops_info.json",
        "framework_tests": "tosa_main_profile_framework_ops_info.json",
    },
}
PROFILES_ALL = "all"

LOCATION_REF_MODEL_BINARY = Path("build/reference_model/tosa_reference_model")

DEFAULT_SEED = 42

# When there is a dictionary of generator argument lists (groups) only the
# standard group will have negative tests generated for it
STANDARD_GENERATOR_GROUP = "standard"


class GenConformanceError(Exception):
    """Generation error reporting exception."""

    pass


def _run_sh_command(args, cwd, full_cmd):
    """Run an external command and capture stdout/stderr."""
    # Quote the command line for printing
    full_cmd_esc = [shlex.quote(x) for x in full_cmd]
    if args.capture_output:
        logger.debug(f"Command: {full_cmd_esc}")

    rc = subprocess.run(
        full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd
    )

    if args.capture_output:
        stdout = rc.stdout.decode("utf-8")
        logger.debug(f"stdout: \n{stdout}")
    if rc.returncode != 0:
        raise Exception(
            "Error running command: {}.\n{}".format(
                " ".join(full_cmd_esc), rc.stderr.decode("utf-8")
            )
        )
    return (rc.stdout, rc.stderr)


def build_op_tests(
    args,
    test_type,
    profile,
    operator,
    group,
    gen_args_list,
    gen_neg_dim_range,
    supports=[],
):
    """Build tests for a given operator.

    Builds a set of tests based on the given generator arguments list.

    Returns the operator output directory.
    """
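    # gen_args_list is a list of argument lists passed through to
    # tosa_verif_build_tests, e.g. (values illustrative):
    #   [["--target-dtype", "int8", "--tensor-dim-range", "1,64"], ...]
    # Only the "--target-dtype" and "--filter" entries are re-used when
    # constructing the negative test command below.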
    build_tests_cmd = "tosa_verif_build_tests"
    op_build_dir = args.build_dir / profile / group

    build_cmd_base = [
        build_tests_cmd,
        "--filter",
        operator,
        "-o",
        str(op_build_dir),
        "--seed",
        str(args.random_seed),
    ]

    if "lazy_data_gen" in supports and args.lazy_data_generation:
        build_cmd_base.append("--lazy-data-generation")

    build_cmds_list = []

    if test_type in ["positive", "both"]:
        # Append extra parameters and run test generator for each set of parameters.
        for arglist in gen_args_list:
            build_cmd_pos_test = build_cmd_base.copy()
            build_cmd_pos_test.extend(["--test-type", "positive"])
            build_cmd_pos_test.extend(arglist)
            build_cmds_list.append(build_cmd_pos_test)

    if test_type in ["negative", "both"]:
        # Get target-dtypes options and any filter string to limit tests
        target_dtypes_args = []
        filter_str = None
        for arglist in gen_args_list:
            idx = 0
            while idx < len(arglist):
                if arglist[idx] == "--target-dtype":
                    if arglist[idx + 1] not in target_dtypes_args:
                        target_dtypes_args.extend(arglist[idx : idx + 2])
                    idx += 1  # skip over option (and then argument below)
                elif arglist[idx] == "--filter":
                    filter_str = arglist[idx + 1]
                    idx += 1  # skip over option (and then argument below)
                idx += 1
        build_cmd_neg_test = build_cmd_base.copy()
        if filter_str:
            build_cmd_neg_test.extend(["--filter", filter_str])
        build_cmd_neg_test.extend(["--test-type", "negative"])
        # Limit sizes of negative tests
        dim_range = gen_neg_dim_range if gen_neg_dim_range is not None else "1,16"

        build_cmd_neg_test.extend(["--tensor-dim-range", dim_range])
        build_cmd_neg_test.extend(target_dtypes_args)
        build_cmds_list.append(build_cmd_neg_test)

    logger.debug(f"Creating {operator} tests with {len(build_cmds_list)} parameter(s)")
    error = False
    for i, cmd in enumerate(build_cmds_list):
        try:
            _run_sh_command(args, args.ref_model_dir.absolute(), cmd)
            logger.info(
                f"{operator} test batch {(i+1)}/{len(build_cmds_list)} created successfully"
            )
        except Exception as e:
            logger.error(
                f"{operator} test batch {(i+1)}/{len(build_cmds_list)} unsuccessful, skipping"
            )
            logger.error(f" build_op_tests error: {e} ")
            error = True
    if error:
        raise (GenConformanceError())

    return op_build_dir


def _check_to_include_test(profile, test_name, exclude_negative_tests=False):
    """Check test name for exclusions, return False to indicate excluded."""
    excludes = ["ERRORIF"] if exclude_negative_tests else []

    for exclusion in excludes:
        if f"_{exclusion}_" in test_name:
            return False
    return True


def _get_all_tests_list(
    profile, test_root_dir, operator, exclude_negative_tests=False, include_all=False
):
    """Create test list based on tests in the test_dir."""
    test_dir = test_root_dir / operator
    if not test_dir.is_dir():
        # Tests are split into multiple dirs, for example: conv2d_1x1, conv2d_3x3
        test_dir = test_root_dir
        directories = [
            tdir for tdir in test_dir.glob("*") if tdir.name.startswith(operator)
        ]
    else:
        directories = [test_dir]

    tests = []
    for tdir in directories:
        tests.extend(
            [
                test
                for test in tdir.glob("*")
                if include_all
                or _check_to_include_test(profile, test.name, exclude_negative_tests)
            ]
        )
    return tests


def generate_results(args, profile, operator, op_build_dir, supports=[], tests=None):
    """Run tests on reference model and save result to the test directory."""
    if "lazy_data_gen" in supports and args.lazy_data_generation:
        logger.info("Skipping running tests due to lazy data gen")
        return

    num_cores = args.num_cores
    run_tests_cmd = "tosa_verif_run_tests"

    ref_model_path = args.ref_model_dir / LOCATION_REF_MODEL_BINARY
    ref_cmd_base = [
        run_tests_cmd,
        "--ref-model-path",
        str(ref_model_path.absolute()),
        "-j",
        str(num_cores),
        "-v",
        "-t",
    ]
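    # "-t" is the final base option; each test path is appended after it below
    # to form one run command per test.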
    ref_cmds = []

    if not tests:
        # Do not need to run ERRORIF tests as they don't have result files
        tests = _get_all_tests_list(
            profile, op_build_dir, operator, exclude_negative_tests=True
        )

    for test in tests:
        ref_cmd = ref_cmd_base.copy()
        ref_cmd.append(str(test))
        ref_cmds.append(ref_cmd)

    fail_string = "UNEXPECTED_FAILURE"
    failed_counter = 0

    job_pool = mp.Pool(args.num_cores)
    sh_partial = partial(_run_sh_command, args, args.ref_model_dir.absolute())
    pool_results = job_pool.map(sh_partial, ref_cmds)
    job_pool.close()
    job_pool.join()

    # Use captured output for run_sh_command to work out if test passed.
    for i, rc in enumerate(pool_results):
        if fail_string in str(rc[0]):
            logger.error(f"Test {i+1}/{len(ref_cmds)}: {ref_cmds[i][-1]} failed.")
            failed_counter += 1
        else:
            logger.info(f"Test {i+1}/{len(ref_cmds)}: {ref_cmds[i][-1]} passed.")

    logger.info(f"{len(ref_cmds)-failed_counter}/{len(ref_cmds)} tests passed")
    logger.info("Ran tests on model and saved results of passing tests")


def convert_tests(
    args,
    profile,
    operator,
    op_build_dir,
    output_dir,
    op_profiles_list,
    supports=[],
    tests=None,
    group=None,
    trim_op_subdir=False,
    tags=None,
):
    """Convert tests to JSON and save to output directory."""
    ref_model_dir = args.ref_model_dir

    if group:
        output_dir = output_dir / group

    c2c_args_base = ["--strict", "--ref-model-directory", str(ref_model_dir)]
    # This op may be in more than one profile - e.g. tosa-bi and tosa-mi -
    # even if we are only producing tests for tosa-mi
    for op_profile in op_profiles_list:
        c2c_args_base.extend(["--profile", op_profile])
    if tags is not None:
        for tag in tags:
            c2c_args_base.extend(["--tag", tag])
    if args.framework_schema:
        c2c_args_base.extend(["--framework-schema", str(args.framework_schema)])
    if "lazy_data_gen" in supports and args.lazy_data_generation:
        c2c_args_base.append("--lazy-data-generation")
    c2c_args_base.append("--output-directory")

    c2c_args_list = []

    if not tests:
        tests = _get_all_tests_list(profile, op_build_dir, operator)
        logger.info(f"Converting all {profile} profile tests")

    # Controls if we copy the tests in their operator sub-directory or not
    output_dir_relative_pos = -1 if trim_op_subdir else -2
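    # e.g. (illustrative layout) for a test at <build_dir>/<operator>/<test_name>:
    #   -2 copies it to <output_dir>/<operator>/<test_name>
    #   -1 drops the operator sub-directory: <output_dir>/<test_name>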
    for test in tests:
        logger.info(f"Test chosen: {test}")
        c2c_args = c2c_args_base.copy()
        full_output_directory = output_dir / test.relative_to(
            *test.parts[:output_dir_relative_pos]
        )
        c2c_args.append(str(full_output_directory))
        c2c_args.append(str(test))
        c2c_args_list.append(c2c_args)

    if len(c2c_args_list) == 0:
        logger.error(
            f"No tests found for {operator}. Nothing to convert in {op_build_dir}"
        )
        raise (GenConformanceError())

    job_pool = mp.Pool(args.num_cores)

    pool_results = job_pool.map(c2c_main, c2c_args_list)
    job_pool.close()
    job_pool.join()

    failed_counter = 0
    for i, result in enumerate(pool_results):
        if result != 0:
            logger.error(
                f"test {i+1}/{len(c2c_args_list)}: {c2c_args_list[i][-1]} failed to convert."
            )
            failed_counter += 1
        else:
            logger.info(
                f"test {i+1}/{len(c2c_args_list)}: {c2c_args_list[i][-1]} converted"
            )
    logger.info(
        f"{len(c2c_args_list)-failed_counter}/{len(c2c_args_list)} tests successfully converted"
    )

    if failed_counter > 0:
        logger.error(f"Stopping due to {failed_counter} test conversion errors")
        raise (GenConformanceError())

    logger.info("Converted tests to JSON and saved to output directory")

    return output_dir


def get_op_tests_selection(
    args,
    profile,
    operator,
    op_build_dir,
    selection_config,
    negative=False,
    ignore_missing=False,
):
    """Use test picker to get subsection of tests generated."""
    # Need a full copy of the config as the selector updates it
    config = copy.deepcopy(selection_config)
    logger.info("Choosing {} tests".format(("negative" if negative else "positive")))
    try:
        op = Operator.registry[operator](
            op_build_dir, config, negative=negative, ignore_missing=ignore_missing
        )
    except KeyError:
        logger.error(f"{operator} operator is not supported by test_select")
        raise (GenConformanceError())

    return op.select_tests()


def check_op_tests(args, profile, operator, output_dir):
    """Move test folders that contain files larger than 30MB to a new directory."""
377 destination_dir = str(args.output_dir) + "_large_files"
378
Jeremy Johnsone4b08ff2022-09-15 10:38:17 +0100379 tests = _get_all_tests_list(profile, output_dir, operator, include_all=True)
Jeremy Johnson0ecfa372022-06-30 14:27:56 +0100380 if not tests:
381 logger.error(
382 f"Couldn't find any tests to size check for {operator} in {output_dir}"
383 )
384 raise (GenConformanceError())
385
386 for tdir in tests:
387 move_dir = False
388 test_files = [file for file in tdir.glob("*")]
389 for file in test_files:
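            # st_size is in bytes; convert to MiB before applying the 30MB limit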
            file_size = os.stat(file).st_size / 1024**2
            if file_size > 30:
                move_dir = True

        if move_dir:
            move_destination = destination_dir / tdir.relative_to(output_dir)
            logger.warning(
                f"{tdir.relative_to(output_dir)} contains files that are too large (>30MB), test moved to new folder: {destination_dir}"
            )

            if move_destination.is_dir():
                logger.warning(
                    f"{move_destination} directory already exists, deleting existing."
                )
                shutil.rmtree(str(move_destination))
            shutil.move(str(tdir), move_destination)


def copy_rename_framework_tests(args, operator, test_picks):
    """Copy framework tests into new folder and rename them if needed.

    The tests are renamed to match the framework operator names if an
    alternate name has been used instead.
    """
    framework_tests_dir = args.framework_tests_dir
    new_tests_dir = args.build_dir / "frameworks" / operator
    os.makedirs(new_tests_dir, exist_ok=True)

    # Get the framework tests operator name
    if "alternate_names" in test_picks[operator]:
        alternate_names = test_picks[operator]["alternate_names"]
    else:
        alternate_names = [operator]

    # Get the alternate named test directories for the operator
    for alt_name in alternate_names:
        test_prefix = f"test_{alt_name}"
        test_dirs = list(framework_tests_dir.glob(f"{test_prefix}_*"))

        # Copy tests to new directory and rename to match framework operator names
        # - if there is just 1 alternate name, replace the full test prefix
        #   test_add_... -> add_...
        # - if there are multiple alternate names, just replace the "test"
        #   test_concatv2_... -> concatenation_concatv2_...
        old_prefix = test_prefix if len(alternate_names) == 1 else "test"

        for tdir in test_dirs:
            new_test_name = tdir.name.replace(old_prefix, operator)
            copy_destination = new_tests_dir / new_test_name
            logger.debug(f"copying test folder {tdir} to {copy_destination}")
            copy_tree(str(tdir), str(copy_destination))

    logger.info(f"Copied and renamed {len(test_dirs)} framework test folders")
    return new_tests_dir.parent


def get_framework_tests_selection(args, operator, test_picks, op_build_dir):
    """Get the list of pre-chosen tests with relative paths."""
    try:
        tests = test_picks[operator]["tests"]
    except KeyError:
        logger.error(f"Framework test selection not defined for {operator} operator")
        raise (GenConformanceError())

    test_paths = [op_build_dir / operator / test for test in tests]
    return test_paths


def parse_args(argv=None):
    """Parse the arguments."""
    parser = argparse.ArgumentParser()
    profiles = list(PROFILE_OPS_INFO.keys())
    profiles.append(PROFILES_ALL)
    parser.add_argument(
        "--profile",
        dest="profile",
        choices=profiles,
        default=profiles[0],
        type=str,
        help=f"TOSA profile (default is {profiles[0]})",
    )
    parser.add_argument(
        "--operators",
        type=str,
        nargs="*",
        help="The operator(s) to create tests for, if not supplied all tests will be created",
    )
    parser.add_argument(
        "--unit-tests",
        dest="unit_tests",
        choices=["operator", "framework", "both"],
        default="operator",
        type=str,
        help="Which unit tests are produced (default is operator)",
    )
    parser.add_argument(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative", "both"],
        default="both",
        type=str,
        help="Type of tests produced (default is both)",
    )
    parser.add_argument(
        "--lazy-data-generation",
        action="store_true",
        help="Enable lazy data generation (only for tosa-mi)",
    )
    parser.add_argument(
        "--ref-model-directory",
        dest="ref_model_dir",
        type=Path,
        required=True,
        help="Reference Model directory (must be pre-built)",
    )
    parser.add_argument(
        "--seed",
        dest="random_seed",
        default=DEFAULT_SEED,
        type=int,
        help="Random test seed",
    )
    parser.add_argument(
        "--framework-tests-directory",
        dest="framework_tests_dir",
        type=Path,
        default=Path.cwd() / "tests",
        help="The pre-built framework tests directory (default is tests)",
    )
    parser.add_argument(
        "--framework-schema",
        dest="framework_schema",
        type=Path,
        help="Framework flatbuffers schema needed to convert framework models",
    )
    parser.add_argument(
        "--build-directory",
        dest="build_dir",
        type=Path,
        default=Path.cwd() / "conformance_build",
        help="Temporary build directory for files created during this process (default is conformance_build)",
    )
    parser.add_argument(
        "--output-directory",
        dest="output_dir",
        type=Path,
        default=Path.cwd() / "conformance",
        help="Output directory (default is conformance)",
    )
    script_dir = Path(__file__).parent.absolute()
    parser.add_argument(
        "--test-param-json-directory",
        dest="param_json_dir",
        type=Path,
        default=script_dir,
        help=f"Test parameters (ops info) JSON file directory (default is {script_dir})",
    )
    parser.add_argument(
        "--convert-all-tests",
        action="store_true",
        help="Converts all tests instead of those picked by test_select",
    )
    parser.add_argument(
        "--keep-large-files",
        action="store_true",
        help="Keeps tests that contain files larger than 30MB in output directory",
    )
    parser.add_argument(
        "--capture-output",
        action="store_true",
        help="Prints output of running sh commands",
    )
    parser.add_argument(
        "-j",
        dest="num_cores",
        type=int,
        default=6,
        help="Number of simultaneous jobs to split the tasks into for multiprocessing",
    )
    parser.add_argument(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    args = parser.parse_args(argv)

    return args


def main():
    args = parse_args()

    if not args.ref_model_dir.is_dir():
        logger.error(
            f"Missing or invalid reference model directory: {args.ref_model_dir}"
        )
        return 2
    else:
        ref_model = args.ref_model_dir / LOCATION_REF_MODEL_BINARY
        if not ref_model.is_file():
            logger.error(
                f"{LOCATION_REF_MODEL_BINARY} not found in {args.ref_model_dir}\nHave you built the reference model?"
            )
            return 2
    if args.unit_tests in ["framework", "both"]:
        logger.warning(
            "DEPRECATION - Framework tests are not part of TOSA conformance testing"
        )
        if not args.framework_schema:
            logger.error(
                "Need to supply location of Framework flatbuffers schema via --framework-schema"
            )
            return 2
        if not args.framework_tests_dir.is_dir():
            logger.error(
                f"Missing or invalid framework tests directory: {args.framework_tests_dir}"
            )
            return 2

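    # Map the -v count onto a log level: 0 -> WARNING, 1 -> INFO, 2+ -> DEBUG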
    loglevels = (logging.WARNING, logging.INFO, logging.DEBUG)
    loglevel = loglevels[min(args.verbosity, len(loglevels) - 1)]
    logger.setLevel(loglevel)
    # Set other loggers the same
    logging.getLogger("test_select").setLevel(loglevel)
    logging.getLogger("convert2conformance").setLevel(loglevel)

    print(f"Output directory: {args.output_dir}")

    if args.random_seed != DEFAULT_SEED:
        logger.warning(
            "Random test seed changed from default, tests will not match official conformance"
        )

    args.build_dir = args.build_dir.resolve()
    logger.debug(f"Creating build directory: {args.build_dir}")
    args.build_dir.mkdir(parents=True, exist_ok=True)

    # TODO: For tosa-mi should really generate tosa-bi profile as well
    # - for now leave it as subset instead of as superset (for testing)
    if args.profile == PROFILES_ALL:
        profiles = list(PROFILE_OPS_INFO.keys())
    else:
        profiles = [args.profile]

    try:
        for profile in profiles:
            print(f"Creating conformance tests for TOSA {profile} profile")
            # Framework unit tests
            if args.unit_tests in ["framework", "both"]:
                logger.debug("Creating FRAMEWORK unit tests")
                test_picks_file = (
                    args.param_json_dir / PROFILE_OPS_INFO[profile]["framework_tests"]
                )
                try:
                    with open(test_picks_file, "r") as fd:
                        test_picks = json.load(fd)
                except Exception as e:
                    logger.error(
                        f"Couldn't load framework tests info - {test_picks_file}: {e}"
                    )
                    return 1

                operators = args.operators
                if not operators:
                    # Create tests for all the operators
                    operators = list(test_picks.keys())

                root_output_dir = (
                    args.output_dir / "frameworks" / "tflite" / "operators"
                )
                for op in operators:
                    logger.info(f"FRAMEWORK OP: {op}")
                    if op not in test_picks:
                        logger.warning(
                            f"Framework op {op} not found in {test_picks_file} - skipping"
                        )
                        continue

                    op_profiles_list = test_picks[op]["profile"]
                    if (
                        args.profile != PROFILES_ALL
                        and args.profile not in op_profiles_list
                    ):
                        # Skip this operator as not part of the profile chosen
                        logger.debug(f"Skipping {op} as not part of {args.profile}")
                        continue

                    logger.debug(f"Copying and renaming {op}")
                    framework_test_dir = copy_rename_framework_tests(
                        args, op, test_picks
                    )

                    if args.convert_all_tests:
                        logger.debug("Running and converting all framework tests")
                        framework_tests = None  # Don't select any
                    else:
                        logger.debug("Running and converting selected framework tests")
                        framework_tests = get_framework_tests_selection(
                            args, op, test_picks, framework_test_dir
                        )
                    convert_tests(
                        args,
                        profile,
                        op,
                        framework_test_dir,
                        root_output_dir,
                        op_profiles_list,
                        tests=framework_tests,
                        trim_op_subdir=True,
                    )

            # Operator unit tests
            if args.unit_tests in ["operator", "both"]:
                logger.debug("Creating OPERATOR unit tests")
                test_params_file = (
                    args.param_json_dir
                    / PROFILE_OPS_INFO[profile]["operator_test_params"]
                )
                try:
                    with open(test_params_file, "r") as fd:
                        test_params = json.load(fd)
                except Exception as e:
                    logger.error(
                        f"Couldn't load operator test params - {test_params_file}: {e}"
                    )
                    return 1

                operators = args.operators
                if not operators:
                    # Create tests for all the operators
                    operators = list(test_params.keys())

                for op in operators:
                    logger.info(f"OPERATOR: {op}")
                    if op not in test_params:
                        logger.warning(
                            f"{op} operator parameters not found in {test_params_file} - skipping"
                        )
                        continue

                    op_profiles_list = test_params[op]["profile"]
                    if (
                        args.profile != PROFILES_ALL
                        and args.profile not in op_profiles_list
                    ):
                        # Skip this operator as not part of the profile chosen
                        logger.debug(f"Skipping {op} as not part of {args.profile}")
                        continue

                    operator_group = test_params[op]["group"]
                    root_output_dir = args.output_dir / "operators"
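                    # Optional per-op "support_for" list from the ops info JSON,
                    # e.g. "lazy_data_gen", which switches the build/run/convert
                    # steps over to lazy data generation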
                    supports = (
                        test_params[op]["support_for"]
                        if "support_for" in test_params[op]
                        else []
                    )

                    # Iterate through the generation groups selecting tests from each
                    for gen_name, gen_dict in test_params[op]["generation"].items():
                        no_neg_tests = (
                            "no_negative_tests" in gen_dict
                            and gen_dict["no_negative_tests"] == "true"
                        )

                        if no_neg_tests:
                            if args.test_type == "negative":
                                logger.info(
                                    f"No negative tests for {op} / generation group {gen_name}"
                                )
                                continue
                            # Only produce positive tests
                            test_type = "positive"
                        else:
                            test_type = args.test_type

                        gen_neg_dim_range = (
                            gen_dict["negative_dim_range"]
                            if "negative_dim_range" in gen_dict
                            else None
                        )

                        ignore_missing = gen_name != STANDARD_GENERATOR_GROUP
                        tags = (
                            [gen_name] if gen_name != STANDARD_GENERATOR_GROUP else None
                        )

                        op_build_dir = build_op_tests(
                            args,
                            test_type,
                            profile,
                            op,
                            gen_name,
                            gen_dict["generator_args"],
                            gen_neg_dim_range,
                            supports=supports,
                        )

                        # Work out which selection criteria we are using
                        if "selector" in gen_dict:
                            selector_name = gen_dict["selector"]
                            if selector_name not in test_params[op]["selection"]:
793 logger.warn(
                                    f"Could not find {selector_name} in selection dict for {op} - using default"
                                )
                                selector_name = "default"
                        else:
                            selector_name = "default"
                        if selector_name not in test_params[op]["selection"]:
                            logger.error(
                                f"Could not find {selector_name} in selection dict for {op}"
                            )
                            raise (GenConformanceError())

                        # Selection criteria
                        selection_config = test_params[op]["selection"][selector_name]

                        if args.convert_all_tests:
                            logger.debug(f"Running and converting all {op} tests")
                            generate_results(
                                args, profile, op, op_build_dir, supports=supports
                            )
                            operator_test_list = None
                        else:
                            logger.debug(
                                f"Running and converting selection of {op} tests"
                            )
                            if test_type in ["positive", "both"]:
                                if (
                                    "all" in selection_config
                                    and selection_config["all"] == "true"
                                ):
                                    # Just get all the positive tests
                                    tests_gen, tests_gen2 = tee(
                                        _get_all_tests_list(
                                            profile,
                                            op_build_dir,
                                            op,
                                            exclude_negative_tests=True,
                                        )
                                    )
                                else:
                                    # Get a selection of positive tests
                                    tests_gen, tests_gen2 = tee(
                                        get_op_tests_selection(
                                            args,
                                            profile,
                                            op,
                                            op_build_dir,
                                            selection_config,
                                            ignore_missing=ignore_missing,
                                        )
                                    )
                                generate_results(
                                    args,
                                    profile,
                                    op,
                                    op_build_dir,
                                    supports=supports,
                                    tests=tests_gen,
                                )
                                operator_test_list = list(tests_gen2)
                            else:
                                operator_test_list = []
                            if test_type in ["negative", "both"]:
                                operator_test_list.extend(
                                    get_op_tests_selection(
                                        args,
                                        profile,
                                        op,
                                        op_build_dir,
                                        selection_config,
                                        negative=True,
                                    )
                                )
                        output_dir = convert_tests(
                            args,
                            profile,
                            op,
                            op_build_dir,
                            root_output_dir,
                            op_profiles_list,
                            supports=supports,
                            tests=operator_test_list,
                            group=operator_group,
                            tags=tags,
                        )
                        if not args.keep_large_files:
                            check_op_tests(args, profile, op, output_dir)
    except GenConformanceError:
        return 1

    return 0


if __name__ == "__main__":
    exit(main())