blob: b1f8d0ebf69944ace123bcf57b2c1f52f16541ab [file] [log] [blame]
Jeremy Johnson6179c212022-01-13 13:46:35 +00001#!/usr/bin/env python3
2# Copyright (c) 2021-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4"""This script converts generated tests into conformance tests.
5
6It can convert a framework unit test or a reference model unit test.
7It expects the tests have been already run on the reference model
8so it can capture the result as the expected result.
9"""
10import argparse
11import json
12import logging
13import os
14from pathlib import Path
15from typing import Optional
16
17from json2fbbin.json2fbbin import fbbin_to_json
18from json2numpy.json2numpy import npy_to_json
19
# Report progress at INFO level by default; main() lowers this module's
# logger to DEBUG when --verbose is given.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("convert2conformance")

# Reference model artifacts, relative to the --ref-model-directory argument.
LOCATION_REF_MODEL_SCHEMA = Path("thirdparty", "serialization_lib", "schema", "tosa.fbs")
LOCATION_REF_MODEL_FLATC = Path(
    "build", "thirdparty", "serialization_lib", "third_party", "flatbuffers", "flatc"
)

# Pieces of the framework flatbuffer test directory name; "_FW_" is a
# placeholder replaced by the framework name (see get_framework_name).
NAME_FLATBUFFER_DIR = ["flatbuffer-", "_FW_"]
NAME_DESC_FILENAME = "desc.json"
# Prefix of the expected-result files created for conformance tests.
NAME_CONFORMANCE_RESULT_PREFIX = "Conformance-"
# Suffix of result files renamed by the test runner (tosa_verif_run_tests).
NAME_REFMODEL_RUN_RESULT_SUFFIX = ".runner.tosa_refmodel_sut_run.npy"

# Supported profiles; the first entry is used when none are given.
PROFILES_LIST = ["tosa-bi", "tosa-mi"]
34
Jeremy Johnson6179c212022-01-13 13:46:35 +000035
def parse_args(argv):
    """Parse the command line arguments.

    Args:
        argv: List of argument strings, or None to let argparse read
            ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "test_dir",
        default=Path.cwd(),
        type=Path,
        nargs="?",
        help="The test directory to convert (default is CWD)",
    )
    parser.add_argument(
        "--ref-model-directory",
        dest="ref_model_dir",
        type=Path,
        required=True,
        help="Reference Model directory (must be pre-built)",
    )
    parser.add_argument(
        "--output-directory",
        dest="output_dir",
        type=Path,
        default=Path.cwd() / "conformance",
        help="Output directory (default is conformance in CWD)",
    )
    parser.add_argument(
        "--framework",
        dest="framework",
        choices=["tflite"],
        default="tflite",
        help="Framework to convert (default tflite)",
    )
    parser.add_argument(
        "--framework-schema",
        dest="framework_schema",
        type=Path,
        help="Framework schema needed to convert framework models",
    )
    parser.add_argument(
        "--profile",
        dest="profile",
        choices=PROFILES_LIST,
        action="append",
        required=True,
        help="Profiles this test is suitable for. May be repeated",
    )
    parser.add_argument(
        "--tag",
        dest="tag",
        action="append",
        type=str,
        help="Optional string tag to mark this test with. May be repeated",
    )
    parser.add_argument(
        "--strict",
        dest="strict",
        action="store_true",
        # main() creates the output directory with exist_ok=False in strict
        # mode, so any pre-existing output directory is an error.
        help="Output directory must not already exist",
    )
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="store_true", help="Verbose operation"
    )
    return parser.parse_args(argv)
100
101
def find_ref_model_artifacts(path: Path):
    """Check the location of the flatc compiler and schema artifacts.

    Args:
        path: Root directory of the (pre-built) reference model.

    Returns:
        Tuple ``(flatc, schema)`` with the paths of the flatbuffers compiler
        and the TOSA flatbuffers schema.

    Raises:
        FileNotFoundError: If either artifact is missing. It subclasses
            ``Exception``, so callers catching ``Exception`` keep working.
    """
    flatc = path / LOCATION_REF_MODEL_FLATC
    schema = path / LOCATION_REF_MODEL_SCHEMA
    if not flatc.is_file():
        raise FileNotFoundError(
            f"flatc not found in {flatc}\nHave you built the flatbuffers compiler?"
        )
    if not schema.is_file():
        raise FileNotFoundError(
            f"TOSA schema not found at {schema}\nHave you checked out the submodules?"
        )
    return flatc, schema
115
116
def find_framework_artifacts(
    framework: str, schema_path: Optional[Path], desc_file: Path
):
    """Check that any required schema has been supplied for conversion.

    Args:
        framework: Framework name; only "tflite" needs extra artifacts.
        schema_path: Framework schema given on the command line (may be None).
        desc_file: Path of the framework test's desc.json; the model is looked
            up two directories above it, as ``model.tflite``.

    Returns:
        ``(schema_path, model_path)`` for tflite, otherwise ``(None, None)``.

    Raises:
        ValueError: If no schema was supplied for a tflite conversion.
        FileNotFoundError: If the schema or the model file does not exist.
        Both subclass ``Exception``, so existing ``except Exception`` callers
        still handle them.
    """
    if framework != "tflite":
        return None, None
    if not schema_path:
        raise ValueError("the following arguments are required: --framework-schema")
    if not schema_path.is_file():
        raise FileNotFoundError(f"framework schema not found at {schema_path}")
    model = desc_file.parent.parent / "model.tflite"
    if not model.is_file():
        raise FileNotFoundError(f"Model file not found at {model}")
    return schema_path, model
129
130
def get_framework_name(name_array: list, framework: str):
    """Build the framework conversion directory name.

    The parts of ``name_array`` are concatenated in order, with every
    ``"_FW_"`` placeholder replaced by ``framework``.
    """
    resolved = (framework if piece == "_FW_" else piece for piece in name_array)
    # format() matches the f-string conversion used to build names elsewhere
    return "".join(format(piece) for piece in resolved)
139
140
def convert_flatbuffer_file(flatc: Path, schema: Path, model_file: Path, output: Path):
    """Convert the flatbuffer binary into JSON.

    Args:
        flatc: Path of the flatbuffers compiler.
        schema: Schema describing ``model_file``.
        model_file: Flatbuffer binary to convert.
        output: Directory the JSON file is written into.

    Returns:
        Path of the generated JSON file, or None if the conversion failed
        (the failure is logged).
    """
    try:
        fbbin_to_json(flatc, schema, model_file, output)
    except Exception as e:
        logger.error(f"Failed to convert flatbuffer binary:\n{e}")
        return None

    if model_file.name == "model.tflite":
        # The converter's output for model.tflite is expected at
        # output/model.json; it is renamed to model-tflite.json.
        file_name = "model-tflite.json"
        # os.replace (unlike os.rename) also overwrites an existing target on
        # Windows, which happens when re-converting into an existing output
        # directory (the default, non --strict mode).
        os.replace(output / "model.json", output / file_name)
    else:
        file_name = model_file.stem + ".json"
    return output / file_name
155
156
def convert_numpy_file(n_file: Path, output: Path, outname: Optional[str] = None):
    """Write the JSON equivalent of a numpy file into ``output``.

    The JSON file is called ``outname`` when a non-empty name is given,
    otherwise it is the numpy file's stem plus ``.json``. Returns the path of
    the JSON file.
    """
    json_path = output / (outname or f"{n_file.stem}.json")
    npy_to_json(n_file, json_path)
    return json_path
162
163
def update_desc_json(
    test_dir: Path,
    test_desc,
    output_dir: Optional[Path] = None,
    create_result=True,
    profiles=None,
    tags=None,
):
    """Update the desc.json format for conformance and optionally create result.

    Args:
        test_dir: Directory holding the original test and its result files.
        test_desc: Parsed desc.json dictionary; it is updated in place.
        output_dir: Where conformance result JSON files are written
            (defaults to ``test_dir``).
        create_result: When true, convert each expected output into a
            ``Conformance-<ofm_name>.json`` file in ``output_dir``.
        profiles: Profiles the test supports (defaults to the first entry of
            PROFILES_LIST).
        tags: Optional list of tags stored under ``"tag"``.

    Returns:
        The updated ``test_desc``, or None if a result file could not be found.
    """
    ofm_files = []
    cfm_files = []
    if not output_dir:
        output_dir = test_dir
    expected_failure = test_desc["expected_failure"]
    for index, ofm in enumerate(test_desc["ofm_file"]):
        ofm_path = test_dir / ofm
        if not expected_failure:
            cfm = NAME_CONFORMANCE_RESULT_PREFIX + test_desc["ofm_name"][index]
            if create_result:
                if ofm_path.is_file():
                    # Use the desc.json name
                    ofm_refmodel = ofm_path
                else:
                    # Adjust for renaming due to tosa_verif_run_tests
                    ofm_refmodel = ofm_path.with_suffix(
                        NAME_REFMODEL_RUN_RESULT_SUFFIX
                    )
                # Create conformance result
                if ofm_refmodel.is_file():
                    convert_numpy_file(ofm_refmodel, output_dir, outname=cfm + ".json")
                else:
                    # Report both locations: the fallback is what was checked last
                    logger.error(
                        f"Missing result file {ofm_path} (also tried {ofm_refmodel})"
                    )
                    return None
            # Recorded as .npy even though .json was written; presumably the
            # test runner converts back (as for ifm files) - TODO confirm
            cfm_files.append(cfm + ".npy")
        # Remove path and "ref-"/"ref_model_" from output filenames
        ofm_files.append(strip_ref_output_name(ofm_path.name))

    # Rewrite output file names as they can be relative, but keep them npys
    test_desc["ofm_file"] = ofm_files
    if not expected_failure:
        # Output expected result file for conformance if expected pass
        test_desc["expected_result_file"] = cfm_files

    # Add supported profiles
    if profiles is None:
        # Assume base profile
        profiles = [PROFILES_LIST[0]]
    test_desc["profile"] = profiles

    # Add tags (if any)
    if tags is not None:
        test_desc["tag"] = tags

    return test_desc
215
216
def strip_ref_output_name(name):
    """Drop reference-model prefixes from an output file name.

    "ref-" is removed first, then "ref_model_", each at most once and only
    when it starts the (possibly already shortened) name.
    """
    for prefix in ("ref-", "ref_model_"):
        if name.startswith(prefix):
            name = name[len(prefix):]
    return name
224
225
def main(argv=None):
    """Convert the given directory to a conformance test.

    Args:
        argv: Command line arguments; None lets argparse read ``sys.argv``.

    Returns:
        Process exit code: 0 on success, 1 when a conversion step fails, and
        2 when required inputs are missing/invalid or the output directory
        cannot be used.
    """
    args = parse_args(argv)
    # Verbosity
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    # Check we can get the files we need
    try:
        flatc, schema = find_ref_model_artifacts(args.ref_model_dir)
    except Exception as err:
        logger.error(err)
        return 2

    # Work out where the desc.json file is: directly in the test directory
    # for a TOSA operator test, otherwise in the framework sub-directory
    desc_filename = args.test_dir / NAME_DESC_FILENAME
    framework_conversion = False
    test_type_desc = "unknown"
    if desc_filename.is_file():
        logger.debug("Found TOSA operator unit test")
        test_type_desc = "TOSA operator"
    else:
        desc_filename = (
            args.test_dir
            / get_framework_name(NAME_FLATBUFFER_DIR, args.framework)
            / NAME_DESC_FILENAME
        )
        if desc_filename.is_file():
            logger.debug(f"Found framework unit test for {args.framework}")
            test_type_desc = f"{args.framework}"
            framework_conversion = True
        else:
            logger.error(f"Could not find {NAME_DESC_FILENAME} in {args.test_dir}")
            return 2
    logger.debug(f"desc.json file: {desc_filename}")

    # Check for required files for framework conversion
    if framework_conversion:
        try:
            framework_schema, framework_filename = find_framework_artifacts(
                args.framework, args.framework_schema, desc_filename
            )
        except Exception as err:
            logger.error(err)
            return 2
    else:
        framework_schema, framework_filename = None, None

    # Open the meta desc.json file (JSON text is UTF-8, so do not depend on
    # the locale's default encoding)
    try:
        with open(desc_filename, mode="r", encoding="utf-8") as fd:
            test_desc = json.load(fd)
    except ValueError as err:
        # Covers json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Could not parse {desc_filename}: {err}")
        return 2

    if "tosa_file" not in test_desc:
        logger.error(f"Unsupported desc.json file found {desc_filename}")
        return 2

    # Dictionary fix: older desc.json files named "ifm_name" "ifm_placeholder"
    if "ifm_name" not in test_desc:
        if "ifm_placeholder" not in test_desc:
            logger.error(f"Unsupported desc.json file found {desc_filename}")
            return 2
        logger.warning("Old format desc.json file found - attempting to fix up")
        test_desc["ifm_name"] = test_desc.pop("ifm_placeholder")

    # Make the output directory if needed; in strict mode it must not exist
    try:
        args.output_dir.mkdir(parents=True, exist_ok=(not args.strict))
    except FileExistsError:
        if args.strict:
            logger.error(f"{args.output_dir} already exists")
        else:
            # With exist_ok=True this only happens for a non-directory path
            logger.error(f"{args.output_dir} is not a directory")
        return 2

    # Convert the TOSA flatbuffer binary
    tosa_filename = desc_filename.parent / test_desc["tosa_file"]
    tosa_filename = convert_flatbuffer_file(
        flatc, schema, tosa_filename, args.output_dir
    )
    if not tosa_filename:
        # Failed to convert the file, convert_flatbuffer_file has logged an error
        return 1
    # Replace binary with JSON name
    test_desc["tosa_file"] = tosa_filename.name

    if framework_conversion and framework_filename:
        # Convert the framework flatbuffer binary
        framework_filename = convert_flatbuffer_file(
            flatc, framework_schema, framework_filename, args.output_dir
        )
        if not framework_filename:
            # Failed to convert the file, convert_flatbuffer_file has logged an error
            return 1

    # Convert input files to JSON
    ifm_files = []
    for file in test_desc["ifm_file"]:
        if file is None:
            ifm_files.append(None)
        else:
            path = desc_filename.parent / file
            convert_numpy_file(path, args.output_dir)
            ifm_files.append(path.name)
    # Rewrite input file names to make sure the paths are correct,
    # but keep them numpys as the test runner will convert them back
    # before giving them to the SUT
    test_desc["ifm_file"] = ifm_files

    # Update desc.json and convert result files to JSON
    test_desc = update_desc_json(
        desc_filename.parent,
        test_desc,
        output_dir=args.output_dir,
        create_result=True,
        profiles=args.profile,
        tags=args.tag,
    )
    if not test_desc:
        # Error from conversion/update
        return 1

    # Output new desc.json
    new_desc_filename = args.output_dir / NAME_DESC_FILENAME
    with open(new_desc_filename, "w", encoding="utf-8") as fd:
        json.dump(test_desc, fd, indent=2)

    logger.info(f"Converted {test_type_desc} test to {args.output_dir}")
    return 0
353
354
if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive shell and is missing when Python runs with -S.
    import sys

    sys.exit(main())