blob: 7182d9c275d1dfd09691287ae10f59fbac0b5ae3 [file] [log] [blame]
Jeremy Johnson0ecfa372022-06-30 14:27:56 +01001#!/usr/bin/env python3
2# Copyright (c) 2021-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4"""Build conformance tests.
5
6Steps:
7- Specific input shapes (or tests) are specified and produced by using the
8 settings in the .json files.
9- Tests are selected to produce a good coverage.
10- Tests are run on the reference model to produce the correct output files.
11- Tests are converted into JSON format and saved to desired output directory.
12"""
13import argparse
14import json
15import logging
16import multiprocessing as mp
17import os
18import shlex
19import shutil
20import subprocess
21from functools import partial
22from itertools import tee
23from pathlib import Path
24
25from conformance.test_select import Operator
26from convert2conformance.convert2conformance import main as c2c_main
27from distutils.dir_util import copy_tree
28
# Install a default root handler; the module logger's level is set in main()
logging.basicConfig()
logger = logging.getLogger(name="tosa_verif_conformance_generator")
31
# Per-profile configuration: the JSON files that drive operator/framework test
# generation, and the type tokens whose tests are excluded for that profile.
PROFILE_OPS_INFO = {
    "tosa-bi": dict(
        operator_test_params="tosa_base_profile_ops_info.json",
        framework_tests="tosa_base_profile_framework_ops_info.json",
        exclude_types=[],
    ),
    "tosa-mi": dict(
        # Only the tests that are additional to the base profile live here
        operator_test_params="tosa_main_profile_ops_info.json",
        framework_tests="tosa_main_profile_framework_ops_info.json",
        exclude_types=[],
    ),
}
# Pseudo-profile name selecting every profile listed above
PROFILES_ALL = "all"

# Reference model executable, relative to --ref-model-directory
LOCATION_REF_MODEL_BINARY = Path("build", "reference_model", "tosa_reference_model")

# Seed used for the official conformance tests
DEFAULT_SEED = 42
51
Jeremy Johnson0ecfa372022-06-30 14:27:56 +010052
class GenConformanceError(Exception):
    """Raised to report a failure while generating conformance tests."""
57
58
59def _run_sh_command(args, cwd, full_cmd):
60 """Run an external command and capture stdout/stderr."""
61 # Quote the command line for printing
62 full_cmd_esc = [shlex.quote(x) for x in full_cmd]
63 if args.capture_output:
64 logger.debug(f"Command: {full_cmd_esc}")
65
66 rc = subprocess.run(
67 full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd
68 )
69
70 if args.capture_output:
71 stdout = rc.stdout.decode("utf-8")
72 logger.debug(f"stdout: \n{stdout}")
73 if rc.returncode != 0:
74
75 raise Exception(
76 "Error running command: {}.\n{}".format(
77 " ".join(full_cmd_esc), rc.stderr.decode("utf-8")
78 )
79 )
80 return (rc.stdout, rc.stderr)
81
82
def _get_target_dtype_args(generator_args):
    """Return the unique ``--target-dtype <dtype>`` option pairs in the arg lists.

    ``generator_args`` is a list of argument lists. The result keeps first-seen
    order and can be appended directly to a command line. A trailing
    ``--target-dtype`` with no value is ignored rather than raising IndexError.
    """
    dtype_args = []
    seen_dtypes = []
    for arglist in generator_args:
        idx = 0
        while idx < len(arglist):
            if arglist[idx] == "--target-dtype":
                if idx + 1 < len(arglist):
                    dtype = arglist[idx + 1]
                    if dtype not in seen_dtypes:
                        seen_dtypes.append(dtype)
                        dtype_args.extend(["--target-dtype", dtype])
                idx += 1  # skip over option (and then argument below)
            idx += 1
    return dtype_args


def build_op_tests(args, profile, operator, test_params):
    """Build tests for a given operator.

    Runs ``tosa_verif_build_tests`` (from the reference model directory) once
    per set of positive generator arguments in ``test_params[operator]`` and,
    when requested, once more for negative tests restricted to the target
    dtypes used by the positive arguments.

    Args:
        args: parsed command line arguments (test type, seed, directories).
        profile: TOSA profile name; used as the build sub-directory.
        operator: operator name to build tests for.
        test_params: operator test parameters loaded from the ops info JSON.

    Returns:
        The operator output directory, ``args.build_dir / profile``.

    Raises:
        GenConformanceError: if any generator run failed (all runs are
            attempted first).
    """
    assert operator in test_params

    build_tests_cmd = "tosa_verif_build_tests"
    op_build_dir = args.build_dir / profile
    generator_args = test_params[operator]["generator_args"]

    ref_cmd_base = [
        build_tests_cmd,
        "--filter",
        operator,
        "-o",
        str(op_build_dir),
        "--seed",
        str(args.random_seed),
    ]

    ref_cmds = []

    if args.test_type in ["positive", "both"]:
        # Run the test generator once for each set of parameters
        for arglist in generator_args:
            ref_cmds.append([*ref_cmd_base, "--test-type", "positive", *arglist])

    if args.test_type in ["negative", "both"]:
        # Limit negative tests to the target dtypes needed and to small sizes
        ref_cmds.append(
            [
                *ref_cmd_base,
                "--test-type",
                "negative",
                "--tensor-dim-range",
                "1,16",
                *_get_target_dtype_args(generator_args),
            ]
        )

    logger.debug(f"Creating {operator} tests with {len(ref_cmds)} parameter(s)")
    error = False
    for i, cmd in enumerate(ref_cmds, start=1):
        try:
            _run_sh_command(args, args.ref_model_dir.absolute(), cmd)
            logger.info(
                f"{operator} test batch {i}/{len(ref_cmds)} created successfully"
            )
        except Exception as e:
            logger.error(
                f"{operator} test batch {i}/{len(ref_cmds)} unsuccessful, skipping"
            )
            logger.error(f" build_op_tests error: {e} ")
            error = True
    if error:
        raise GenConformanceError(f"Failed to build tests for {operator}")

    return op_build_dir
151
152
def _check_to_include_test(profile, test_name, exclude_negative_tests=False):
    """Decide whether a test should be kept, based on its name.

    A type token ``X`` excludes the test when ``_X_`` occurs in ``test_name``.
    The excluded tokens are the profile's ``exclude_types``, plus ``ERRORIF``
    when ``exclude_negative_tests`` is set.

    Returns:
        False if the test is excluded, otherwise True.
    """
    excluded_tokens = ["ERRORIF"] if exclude_negative_tests else []
    excluded_tokens += PROFILE_OPS_INFO[profile]["exclude_types"]
    return not any(f"_{token}_" in test_name for token in excluded_tokens)
162
163
164def _get_all_tests_list(
165 profile, test_root_dir, operator, exclude_negative_tests=False, include_all=False
166):
167 """Create test list based on tests in the test_dir."""
168 test_dir = test_root_dir / operator
169 if not test_dir.is_dir():
170 # Tests are split into multiple dirs, for example: conv2d_1x1, conv2d_3x3
171 test_dir = test_root_dir
172 directories = [
173 tdir for tdir in test_dir.glob("*") if tdir.name.startswith(operator)
174 ]
175 else:
176 directories = [test_dir]
177
178 tests = []
179 for tdir in directories:
180 tests.extend(
181 [
182 test
183 for test in tdir.glob("*")
184 if include_all
185 or _check_to_include_test(profile, test.name, exclude_negative_tests)
186 ]
187 )
188 return tests
189
190
def generate_results(args, profile, operator, op_build_dir, tests=None):
    """Run tests on reference model and save result to the test directory.

    Each test is run through ``tosa_verif_run_tests`` in a multiprocessing
    pool. A test counts as failed when its captured stdout contains
    "UNEXPECTED_FAILURE"; results are only logged.

    Args:
        args: parsed command line arguments (core count, ref model directory).
        profile: TOSA profile name, used when listing tests.
        operator: operator whose tests are run.
        op_build_dir: directory the operator tests were built into.
        tests: iterable of test paths to run. When None, all non-ERRORIF
            tests for ``operator`` under ``op_build_dir`` are run. An
            explicitly empty iterable runs nothing.
    """
    num_cores = args.num_cores
    run_tests_cmd = "tosa_verif_run_tests"

    ref_model_path = args.ref_model_dir / LOCATION_REF_MODEL_BINARY
    ref_cmd_base = [
        run_tests_cmd,
        "--ref-model-path",
        str(ref_model_path.absolute()),
        "-j",
        str(num_cores),
        "-v",
        "-t",
    ]

    if tests is None:
        # Do not need to run ERRORIF tests as they don't have result files
        tests = _get_all_tests_list(
            profile, op_build_dir, operator, exclude_negative_tests=True
        )

    ref_cmds = [ref_cmd_base + [str(test)] for test in tests]

    fail_string = "UNEXPECTED_FAILURE"
    failed_counter = 0

    sh_partial = partial(_run_sh_command, args, args.ref_model_dir.absolute())
    # Context manager so the worker processes are cleaned up even when a
    # command raises inside the pool
    with mp.Pool(num_cores) as job_pool:
        pool_results = job_pool.map(sh_partial, ref_cmds)

    # Use captured output for run_sh_command to work out if test passed.
    total = len(ref_cmds)
    for i, (cmd, (stdout, _stderr)) in enumerate(zip(ref_cmds, pool_results), start=1):
        if fail_string in stdout.decode("utf-8", errors="replace"):
            logger.error(f"Test {i}/{total}: {cmd[-1]} failed.")
            failed_counter += 1
        else:
            logger.info(f"Test {i}/{total}: {cmd[-1]} passed.")

    logger.info(f"{total-failed_counter}/{total} tests passed")
    logger.info("Ran tests on model and saved results of passing tests")
238
239
def convert_tests(
    args,
    profile,
    operator,
    op_build_dir,
    output_dir,
    op_profiles_list,
    tests=None,
    group=None,
    trim_op_subdir=False,
):
    """Convert tests to JSON and save to output directory.

    Each test is converted by calling convert2conformance's ``main`` in a
    multiprocessing pool; a non-zero return counts as a failed conversion.

    Args:
        args: parsed command line arguments.
        profile: TOSA profile name, used when listing tests.
        operator: operator whose tests are converted.
        op_build_dir: directory containing the built tests.
        output_dir: root output directory.
        op_profiles_list: every profile the operator belongs to; each is
            passed to the converter as ``--profile``.
        tests: test paths to convert. When None, all tests for ``operator``
            under ``op_build_dir`` are converted. An explicitly empty
            selection converts nothing.
        group: optional sub-directory of ``output_dir`` to write into.
        trim_op_subdir: if True, drop the operator sub-directory from the
            output path of each test.

    Returns:
        The directory the tests were converted into, or None when there
        were no tests to convert.

    Raises:
        GenConformanceError: if any test failed to convert.
    """
    ref_model_dir = args.ref_model_dir

    if group:
        output_dir = output_dir / group

    ref_cmd_base = ["--ref-model-directory", str(ref_model_dir)]
    # This op may be in more than one profile - e.g. tosa_bi and tosa_mi
    # even if we are only producing tests for tosa_mi
    for op_profile in op_profiles_list:
        ref_cmd_base.extend(["--profile", op_profile])
    if args.framework_schema:
        ref_cmd_base.extend(["--framework-schema", str(args.framework_schema)])
    ref_cmd_base.append("--output-directory")

    # Only None means "no selection made"; an empty selection must not
    # silently turn into converting every test
    if tests is None:
        tests = _get_all_tests_list(profile, op_build_dir, operator)
        logger.info(f"Converting all {profile} profile tests")

    # Controls if we copy the tests in their operator sub-directory or not
    output_dir_relative_pos = -1 if trim_op_subdir else -2
    ref_cmds = []
    for test in tests:
        logger.info(f"Test chosen: {test}")
        full_output_directory = output_dir / test.relative_to(
            *test.parts[:output_dir_relative_pos]
        )
        ref_cmds.append(ref_cmd_base + [str(full_output_directory), str(test)])

    if not ref_cmds:
        logger.warning("No tests found. Nothing to convert")
        return None

    # Context manager so worker processes are cleaned up on errors too
    with mp.Pool(args.num_cores) as job_pool:
        pool_results = job_pool.map(c2c_main, ref_cmds)

    total = len(ref_cmds)
    failed_counter = 0
    for i, (cmd, result) in enumerate(zip(ref_cmds, pool_results), start=1):
        if result != 0:
            logger.error(f"test {i}/{total}: {cmd[-1]} failed to convert.")
            failed_counter += 1
        else:
            logger.info(f"test {i}/{total}: {cmd[-1]} converted")
    logger.info(f"{total-failed_counter}/{total} tests successfully converted")

    if failed_counter > 0:
        logger.error(f"Stopping due to {failed_counter} test conversion errors")
        raise GenConformanceError()

    logger.info("Converted tests to JSON and saved to output directory")

    return output_dir
314
315
def get_op_tests_selection(
    args, profile, operator, op_build_dir, test_params, negative=False
):
    """Pick a subset of an operator's generated tests using test_select.

    Args:
        args: parsed command line arguments (not used here).
        profile: TOSA profile name; its ``exclude_types`` are passed on.
        operator: operator name; must be a key of ``test_params``.
        op_build_dir: directory the operator tests were built into.
        test_params: operator test parameters loaded from the ops info JSON.
        negative: choose negative (ERRORIF) tests instead of positive ones.

    Returns:
        The result of ``select_tests()`` on the registered test_select
        Operator for ``operator``.

    Raises:
        GenConformanceError: if a KeyError occurs while looking up or
            constructing the test_select Operator.
    """
    assert operator in test_params
    selection_kind = "negative" if negative else "positive"
    logger.info("Choosing {} tests".format(selection_kind))
    try:
        op_params = test_params[operator]
        selector_cls = Operator.registry[operator]
        selector = selector_cls(
            op_build_dir,
            op_params,
            negative,
            exclude_types=PROFILE_OPS_INFO[profile]["exclude_types"],
        )
    except KeyError:
        logger.error(f"{operator} operator is not supported by test_select")
        raise GenConformanceError()
    else:
        return selector.select_tests()
335
336
def check_op_tests(args, profile, operator, output_dir):
    """Move test folders that contain files larger than 30MB to a new directory.

    Oversized tests are moved out of ``output_dir`` into
    ``<args.output_dir>_large_files``, keeping their path relative to
    ``output_dir``. Any existing copy at the destination is replaced.

    Args:
        args: parsed command line arguments (uses ``output_dir``).
        profile: TOSA profile name, used when listing tests.
        operator: operator whose converted tests are checked.
        output_dir: directory the operator tests were converted into, as
            returned by convert_tests. None (nothing was converted) is
            accepted and nothing is checked.

    Raises:
        GenConformanceError: if no tests can be found in ``output_dir``.
    """
    if output_dir is None:
        # convert_tests returns None when it had nothing to convert
        logger.warning(f"No converted tests to size check for {operator}")
        return

    max_file_size_mb = 30
    destination_dir = Path(str(args.output_dir) + "_large_files")

    tests = _get_all_tests_list(profile, output_dir, operator, include_all=True)
    if not tests:
        logger.error(
            f"Couldn't find any tests to size check for {operator} in {output_dir}"
        )
        raise GenConformanceError()

    for tdir in tests:
        # Stop at the first oversized file - one is enough to move the test
        move_dir = any(
            os.stat(test_file).st_size / 1024**2 > max_file_size_mb
            for test_file in tdir.glob("*")
        )
        if not move_dir:
            continue

        move_destination = destination_dir / tdir.relative_to(output_dir)
        logger.warning(
            f"{tdir.relative_to(output_dir)} contains files that are too large (>30MB), test moved to new folder: {destination_dir}"
        )

        if move_destination.is_dir():
            logger.warning(
                f"{move_destination} directory already exists, deleting existing."
            )
            shutil.rmtree(str(move_destination))
        move_destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(tdir), str(move_destination))
368
369
def copy_rename_framework_tests(args, operator, test_picks):
    """Copy framework tests into new folder and rename them if needed.

    The tests are renamed to match the framework operator names if an
    alternate name has been used instead.

    Args:
        args: parsed command line arguments (uses ``framework_tests_dir`` and
            ``build_dir``).
        operator: framework operator name to copy tests for.
        test_picks: framework test info loaded from JSON; the operator entry
            may list ``alternate_names`` used by the source test folders.

    Returns:
        The ``<build_dir>/frameworks`` directory containing the per-operator
        test copies.
    """
    framework_tests_dir = args.framework_tests_dir
    new_tests_dir = args.build_dir / "frameworks" / operator
    os.makedirs(new_tests_dir, exist_ok=True)

    # Get the framework tests operator name
    if "alternate_names" in test_picks[operator]:
        alternate_names = test_picks[operator]["alternate_names"]
    else:
        alternate_names = [operator]

    copied_count = 0
    # Get the alternate named test directories for the operator
    for alt_name in alternate_names:
        test_prefix = f"test_{alt_name}"
        test_dirs = list(framework_tests_dir.glob(f"{test_prefix}_*"))

        # Copy tests to new directory and rename to match framework operator names
        # - if there is just 1 alternate name, replace the full test prefix
        #   test_add_... -> add_...
        # - if there are multiple alternate names, just replace the "test"
        #   test_concatv2_... -> concatenation_concatv2_...
        old_prefix = test_prefix if len(alternate_names) == 1 else "test"

        for tdir in test_dirs:
            # Only the leading prefix is replaced (count=1): the glob above
            # guarantees the name starts with it, while later parts of the
            # name may contain the same text (e.g. "test" in "latest")
            new_test_name = tdir.name.replace(old_prefix, operator, 1)
            copy_destination = new_tests_dir / new_test_name
            logger.debug(f"copying test folder {tdir} to {copy_destination}")
            copy_tree(str(tdir), str(copy_destination))
        copied_count += len(test_dirs)

    logger.info(f"Copied and renamed {copied_count} framework test folders")
    return new_tests_dir.parent
406
407
408def get_framework_tests_selection(args, operator, test_picks, op_build_dir):
409 """Get the list of pre-chosen tests with relative paths."""
410 try:
411 tests = test_picks[operator]["tests"]
412 except KeyError:
413 logger.error(f"Framework test selection not defined for {operator} operator")
414 raise (GenConformanceError())
415
416 test_paths = [op_build_dir / operator / test for test in tests]
417 return test_paths
418
419
def parse_args(argv=None):
    """Parse the command line arguments.

    Args:
        argv: list of arguments to parse; None parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument

    profile_choices = [*PROFILE_OPS_INFO.keys(), PROFILES_ALL]
    default_profile = profile_choices[0]
    cwd = Path.cwd()
    script_dir = Path(__file__).parent.absolute()

    add(
        "--profile",
        dest="profile",
        choices=profile_choices,
        default=default_profile,
        type=str,
        help=f"TOSA profile (default is {default_profile})",
    )
    add(
        "--operators",
        type=str,
        nargs="*",
        help="The operator(s) to create tests for, if not supplied all tests will be created",
    )
    add(
        "--unit-tests",
        dest="unit_tests",
        choices=["operator", "framework", "both"],
        default="operator",
        type=str,
        help="Which unit tests are produced (default is operator)",
    )
    add(
        "--test-type",
        dest="test_type",
        choices=["positive", "negative", "both"],
        default="both",
        type=str,
        help="Type of tests produced (default is both)",
    )
    add(
        "--ref-model-directory",
        dest="ref_model_dir",
        type=Path,
        required=True,
        help="Reference Model directory (must be pre-built)",
    )
    add(
        "--seed",
        dest="random_seed",
        default=DEFAULT_SEED,
        type=int,
        help="Random test seed",
    )
    add(
        "--framework-tests-directory",
        dest="framework_tests_dir",
        type=Path,
        default=cwd / "tests",
        help="The pre-built framework tests directory (default is tests)",
    )
    add(
        "--framework-schema",
        dest="framework_schema",
        type=Path,
        help="Framework flatbuffers schema needed to convert framework models",
    )
    add(
        "--build-directory",
        dest="build_dir",
        type=Path,
        default=cwd / "conformance_build",
        help="Temporary build directory for files created during this process (default is conformance_build)",
    )
    add(
        "--output-directory",
        dest="output_dir",
        type=Path,
        default=cwd / "conformance",
        help="Output directory (default is conformance)",
    )
    add(
        "--test-param-json-directory",
        dest="param_json_dir",
        type=Path,
        default=script_dir,
        help=f"Test parameters (ops info) JSON file directory (default is {script_dir})",
    )
    add(
        "--convert-all-tests",
        action="store_true",
        help="Converts all tests instead of those picked by test_select",
    )
    add(
        "--keep-large-files",
        action="store_true",
        help="Keeps tests that contain files larger than 30MB in output directory",
    )
    add(
        "--capture-output",
        action="store_true",
        help="Prints output of running sh commands",
    )
    add(
        "-j",
        dest="num_cores",
        type=int,
        default=6,
        help="Number of simultaneous jobs to split the tasks into for multiprocessing",
    )
    add(
        "-v",
        dest="verbosity",
        action="count",
        default=0,
        help="Verbosity (can be used multiple times for more details)",
    )
    return arg_parser.parse_args(argv)
536
537
def main():
    """Generate the conformance tests selected on the command line.

    Validates the arguments, then for each selected profile creates the
    framework and/or operator unit tests (build, run on the reference model,
    select and convert), optionally moving oversized tests aside.

    Returns:
        0 on success, 2 for invalid or missing inputs, 1 if a test info
        JSON file cannot be loaded or a GenConformanceError is raised.
        Other exceptions are not caught here and propagate.
    """
    args = parse_args()

    # Validate the required inputs before doing any work
    if not args.ref_model_dir.is_dir():
        logger.error(
            f"Missing or invalid reference model directory: {args.ref_model_dir}"
        )
        return 2
    else:
        ref_model = args.ref_model_dir / LOCATION_REF_MODEL_BINARY
        if not ref_model.is_file():
            logger.error(
                f"{LOCATION_REF_MODEL_BINARY} not found in {args.ref_model_dir}\nHave you built the reference model?"
            )
            return 2
    if args.unit_tests in ["framework", "both"]:
        if not args.framework_schema:
            logger.error(
                "Need to supply location of Framework flatbuffers schema via --framework-schema"
            )
            return 2
        if not args.framework_tests_dir.is_dir():
            logger.error(
                f"Missing or invalid framework tests directory: {args.framework_tests_dir}"
            )
            return 2

    # Map the -v count onto a log level, capped at DEBUG
    loglevels = (logging.WARNING, logging.INFO, logging.DEBUG)
    loglevel = loglevels[min(args.verbosity, len(loglevels) - 1)]
    logger.setLevel(loglevel)
    # Set other loggers the same
    logging.getLogger("test_select").setLevel(loglevel)
    logging.getLogger("convert2conformance").setLevel(loglevel)

    print(f"Output directory: {args.output_dir}")

    if args.random_seed != DEFAULT_SEED:
        logger.warning(
            "Random test seed changed from default, tests will not match official conformance"
        )

    args.build_dir = args.build_dir.resolve()
    logger.debug(f"Creating build directory: {args.build_dir}")
    args.build_dir.mkdir(parents=True, exist_ok=True)

    # TODO: For tosa-mi should really generate tosa-bi profile as well
    # - for now leave it as subset instead of as superset (for testing)
    if args.profile == PROFILES_ALL:
        profiles = list(PROFILE_OPS_INFO.keys())
    else:
        profiles = [args.profile]

    try:
        for profile in profiles:
            print(f"Creating conformance tests for TOSA {profile} profile")
            # Framework unit tests
            if args.unit_tests in ["framework", "both"]:
                logger.debug("Creating FRAMEWORK unit tests")
                test_picks_file = (
                    args.param_json_dir / PROFILE_OPS_INFO[profile]["framework_tests"]
                )
                try:
                    with open(test_picks_file, "r") as fd:
                        test_picks = json.load(fd)
                except Exception as e:
                    logger.error(
                        f"Couldn't load framework tests info - {test_picks_file}: {e}"
                    )
                    return 1

                operators = args.operators
                if not operators:
                    # Create tests for all the operators
                    operators = list(test_picks.keys())

                root_output_dir = (
                    args.output_dir / "frameworks" / "tflite" / "operators"
                )
                for op in operators:
                    logger.info(f"FRAMEWORK OP: {op}")
                    if op not in test_picks:
                        logger.warning(
                            f"Framework op {op} not found in {test_picks_file} - skipping"
                        )
                        continue

                    op_profiles_list = test_picks[op]["profile"]
                    if (
                        args.profile != PROFILES_ALL
                        and args.profile not in op_profiles_list
                    ):
                        # Skip this operator as not part of the profile chosen
                        logger.debug(f"Skipping {op} as not part of {args.profile}")
                        continue

                    logger.debug(f"Copying and renaming {op}")
                    framework_test_dir = copy_rename_framework_tests(
                        args, op, test_picks
                    )

                    # None asks convert_tests to convert every framework test
                    if args.convert_all_tests:
                        logger.debug("Running and converting all framework tests")
                        framework_tests = None  # Don't select any
                    else:
                        logger.debug("Running and converting selected framework tests")
                        framework_tests = get_framework_tests_selection(
                            args, op, test_picks, framework_test_dir
                        )
                    convert_tests(
                        args,
                        profile,
                        op,
                        framework_test_dir,
                        root_output_dir,
                        op_profiles_list,
                        tests=framework_tests,
                        trim_op_subdir=True,
                    )

            # Operator unit tests
            if args.unit_tests in ["operator", "both"]:
                logger.debug("Creating OPERATOR unit tests")
                test_params_file = (
                    args.param_json_dir
                    / PROFILE_OPS_INFO[profile]["operator_test_params"]
                )
                try:
                    with open(test_params_file, "r") as fd:
                        test_params = json.load(fd)
                except Exception as e:
                    logger.error(
                        f"Couldn't load operator test params - {test_params_file}: {e}"
                    )
                    return 1

                operators = args.operators
                if not operators:
                    # Create tests for all the operators
                    operators = list(test_params.keys())

                for op in operators:
                    logger.info(f"OPERATOR: {op}")
                    if op not in test_params:
                        logger.warning(
                            f"{op} operator parameters not found in {test_params_file} - skipping"
                        )
                        continue

                    # Negative-only run for an op flagged as having none: skip it
                    if (
                        args.test_type == "negative"
                        and "no_negative_tests" in test_params[op]
                        and test_params[op]["no_negative_tests"]
                    ):
                        logger.warning(f"No negative tests for {op}")
                        continue

                    op_profiles_list = test_params[op]["profile"]
                    if (
                        args.profile != PROFILES_ALL
                        and args.profile not in op_profiles_list
                    ):
                        # Skip this operator as not part of the profile chosen
                        logger.debug(f"Skipping {op} as not part of {args.profile}")
                        continue

                    op_build_dir = build_op_tests(args, profile, op, test_params)

                    operator_group = test_params[op]["group"]
                    root_output_dir = args.output_dir / "operators"
                    if args.convert_all_tests:
                        logger.debug(f"Running and converting all {op} tests")
                        generate_results(args, profile, op, op_build_dir)
                        operator_test_list = None
                    else:
                        logger.debug(f"Running and converting selection of {op} tests")
                        if args.test_type in ["positive", "both"]:
                            # tee: one copy is consumed by generate_results,
                            # the other is kept as the list to convert
                            tests_gen, tests_gen2 = tee(
                                get_op_tests_selection(
                                    args, profile, op, op_build_dir, test_params
                                )
                            )
                            generate_results(args, profile, op, op_build_dir, tests_gen)
                            operator_test_list = list(tests_gen2)
                        else:
                            operator_test_list = []
                        # Negative tests are only selected, not run, before
                        # conversion
                        if args.test_type in ["negative", "both"] and (
                            "no_negative_tests" not in test_params[op]
                            or not test_params[op]["no_negative_tests"]
                        ):
                            operator_test_list.extend(
                                get_op_tests_selection(
                                    args,
                                    profile,
                                    op,
                                    op_build_dir,
                                    test_params,
                                    negative=True,
                                )
                            )
                    output_dir = convert_tests(
                        args,
                        profile,
                        op,
                        op_build_dir,
                        root_output_dir,
                        op_profiles_list,
                        tests=operator_test_list,
                        group=operator_group,
                    )
                    # NOTE(review): convert_tests returns None when it found
                    # nothing to convert; check_op_tests must tolerate that.
                    if not args.keep_large_files:
                        check_op_tests(args, profile, op, output_dir)
    except GenConformanceError:
        return 1

    return 0
753
754
if __name__ == "__main__":
    import sys

    # Use sys.exit rather than the site-provided exit() builtin, which is not
    # guaranteed to exist (e.g. under `python -S` or in frozen executables)
    sys.exit(main())