blob: 9de421bbb7733f30c67aac00b4bdfde7e3dd8839 [file] [log] [blame]
# Copyright (c) 2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
3"""Calls the data generation library to create the test data."""
4import ctypes as ct
5import json
6from pathlib import Path
7
8import numpy as np
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01009import schemavalidation.schemavalidation as sch
Jeremy Johnson65ba8092023-10-09 16:31:13 +010010
11
class GenerateError(Exception):
    """Raised when the data generation library cannot be set up or used."""
14
15
class GenerateLibrary:
    """Python interface to the C generate library.

    Simple usage to write out all input files:
        set_config(test_desc)
        write_numpy_files(test_path)

    To get data buffers (for const data):
        get_tensor_data(tensor_name)
    """

    def __init__(self, generate_lib_path):
        """Find the library and set up the interface.

        generate_lib_path - path (str or Path) to the generate shared library

        Raises GenerateError if the library file cannot be found.
        """
        # Accept plain strings as well as Path objects
        self.lib_path = Path(generate_lib_path)
        if not self.lib_path.is_file():
            raise GenerateError(f"Could not find generate library - {self.lib_path}")

        self.schema_validator = sch.TestDescSchemaValidator()

        self.test_desc = None
        self.json_config = None
        # Pass a str as path-like names are only documented as supported by
        # ctypes from Python 3.12 onwards
        self.lib = ct.cdll.LoadLibrary(str(self.lib_path))

        self.tgd_generate_data = self.lib.tgd_generate_data
        self.tgd_generate_data.argtypes = [
            ct.c_char_p,  # JSON config
            ct.c_char_p,  # tensor name
            ct.c_void_p,  # output data buffer
            ct.c_size_t,  # size of the buffer in bytes
        ]
        self.tgd_generate_data.restype = ct.c_bool

    def check_config(self, test_desc: dict):
        """Quick check that the config supports data generation.

        Returns True when test_desc contains a meta/data_gen section.
        """
        return ("meta" in test_desc) and ("data_gen" in test_desc["meta"])

    def set_config(self, test_desc: dict):
        """Set the test config in the library.

        test_desc - the test desc.json file

        Raises GenerateError if there is no meta/data_gen section.
        Any previously set config is cleared, even on failure.
        """
        self.test_desc = None
        self.json_config = None

        if not self.check_config(test_desc):
            raise GenerateError("No meta/data_gen section found in desc.json")

        # Validate the config versus the schema
        self.schema_validator.validate_config(test_desc)

        self.test_desc = test_desc
        self.json_config = test_desc["meta"]["data_gen"]

    def _create_buffer(self, dtype: str, shape: tuple):
        """Helper to create a zeroed buffer of the required type.

        Returns a tuple of (ctypes array, size of the array in bytes).
        Raises GenerateError for an unsupported data type.
        """
        # Force an integer product - np.prod(()) defaults to the float 1.0
        # which ctypes cannot use as an array length (rank 0 tensors)
        size = int(np.prod(shape, dtype=np.int64))

        # NOTE: ctypes arrays are zero initialized on creation, so no explicit
        # initializer is given - passing one fails for zero length arrays
        if dtype == "FP32":
            buffer = (ct.c_float * size)()
            size_bytes = size * 4
        elif dtype == "FP16":
            size_bytes = size * 2
            # No ctypes half float type, so use a buffer of bytes
            buffer = (ct.c_ubyte * size_bytes)()
        else:
            raise GenerateError(f"Unsupported data type {dtype}")

        return buffer, size_bytes

    def _convert_buffer(self, buffer, dtype: str, shape: tuple):
        """Helper to convert a buffer to a numpy array of the given shape."""
        arr = np.ctypeslib.as_array(buffer)

        if dtype == "FP16":
            # Convert from bytes back to FP16
            arr = np.frombuffer(arr, np.float16)

        arr = np.reshape(arr, shape)

        return arr

    def _data_gen_array(self, json_config: dict, tensor_name: str):
        """Generate the named tensor data and return a numpy array.

        json_config - the data_gen config (as a dictionary)
        tensor_name - name of a tensor in the config's "tensors" section

        Raises GenerateError if the tensor information is missing or the
        library reports a failure.
        """
        try:
            tensor = json_config["tensors"][tensor_name]
            dtype = tensor["data_type"]
            shape = tuple(tensor["shape"])
        except KeyError as e:
            raise GenerateError(
                f"Missing data in json config for input {tensor_name} - {repr(e)}"
            ) from e

        buffer, size_bytes = self._create_buffer(dtype, shape)
        buffer_ptr = ct.cast(buffer, ct.c_void_p)

        json_bytes = bytes(json.dumps(json_config), "utf8")

        result = self.tgd_generate_data(
            ct.c_char_p(json_bytes),
            ct.c_char_p(bytes(tensor_name, "utf8")),
            buffer_ptr,
            ct.c_size_t(size_bytes),
        )
        if not result:
            raise GenerateError("Data generate failed")

        return self._convert_buffer(buffer, dtype, shape)

    def _data_gen_write(
        self, test_path: Path, json_config: dict, ifm_name: str, ifm_file: str
    ):
        """Generate the named tensor data and save it in numpy format."""
        arr = self._data_gen_array(json_config, ifm_name)

        file_name = test_path / ifm_file
        np.save(file_name, arr)

    def write_numpy_files(self, test_path: Path):
        """Write out all the desc.json input tensors to numpy data files.

        test_path - directory that the ifm_file entries are written into

        Raises GenerateError if no config is set, the desc.json input lists
        are missing or inconsistent, or any tensor fails to generate (all
        tensors are attempted and the failures reported together).
        """
        if self.test_desc is None or self.json_config is None:
            raise GenerateError("Cannot write numpy files as no config set up")

        try:
            ifm_names = self.test_desc["ifm_name"]
            ifm_files = self.test_desc["ifm_file"]
        except KeyError as e:
            raise GenerateError(f"Missing data in desc.json - {repr(e)}") from e

        # zip() would silently ignore the unmatched entries
        if len(ifm_names) != len(ifm_files):
            raise GenerateError(
                f"Mismatched number of ifm_name ({len(ifm_names)}) and "
                f"ifm_file ({len(ifm_files)}) entries in desc.json"
            )

        failures = []
        for iname, ifile in zip(ifm_names, ifm_files):
            try:
                self._data_gen_write(test_path, self.json_config, iname, ifile)
            except GenerateError as e:
                failures.append(
                    f"ERROR: Failed to create data for tensor {iname} - {repr(e)}"
                )

        if failures:
            raise GenerateError("\n".join(failures))

    def get_tensor_data(self, tensor_name: str, json_config=None):
        """Get a numpy array for a named tensor in the data_gen meta data.

        tensor_name - name of the tensor to generate
        json_config - optional data_gen config to use (validated against the
                      schema) instead of the one from set_config()

        Raises GenerateError if no config is given and none has been set.
        """
        if json_config is None:
            if self.json_config is None:
                raise GenerateError("Cannot get tensor data as no config set up")
            json_config = self.json_config
        else:
            # Validate the given config
            self.schema_validator.validate_config(
                json_config, schema_type=sch.TD_SCHEMA_DATA_GEN
            )

        return self._data_gen_array(json_config, tensor_name)
171
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100172
def main(argv=None):
    """Command line entry point to generate the input data for one test.

    argv - optional list of arguments (passed to argparse's parse_args)

    Returns 2 when the test directory or desc.json cannot be used or the
    test has no data generation section, 1 when the generate library
    reports an error, and nothing (None) on success.
    """
    import argparse
    import conformance.model_files as cmf

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--generate-lib-path",
        type=Path,
        help="Path to TOSA generate lib",
    )
    arg_parser.add_argument(
        "path", type=Path, help="the path to the test directory to generate data for"
    )
    args = arg_parser.parse_args(argv)
    test_path = args.path

    lib_path = args.generate_lib_path
    if lib_path is None:
        # No explicit path given, so try to locate the generate library
        # relative to this file. This default only works for the python
        # developer environment (i.e. when using the scripts/py-dev-env.*
        # scripts); otherwise use --generate-lib-path to specify the path
        ref_model_dir = Path(__file__).absolute().parents[2]
        lib_path = cmf.find_tosa_file(
            cmf.TosaFileType.GENERATE_LIBRARY, ref_model_dir, False
        )

    if not test_path.is_dir():
        print(f"ERROR: Invalid directory - {test_path}")
        return 2

    test_desc_path = test_path / "desc.json"
    if not test_desc_path.is_file():
        print(f"ERROR: No test description found: {test_desc_path}")
        return 2

    # Read in the test description
    try:
        with test_desc_path.open("r") as desc_file:
            test_desc = json.load(desc_file)
    except Exception as err:
        print(f"ERROR: Loading {test_desc_path} - {repr(err)}")
        return 2

    # Set up the library with the test description
    try:
        generator = GenerateLibrary(lib_path)
        if not generator.check_config(test_desc):
            print(f"WARNING: No data generation supported for {test_path}")
            return 2

        generator.set_config(test_desc)
    except GenerateError as err:
        print(f"ERROR: Initializing generate library - {repr(err)}")
        return 1

    # Create the numpy input files in the test directory
    try:
        generator.write_numpy_files(test_path)
    except GenerateError as err:
        print(f"ERROR: Writing out data files to {test_path}\n{repr(err)}")
        return 1
234
235
if __name__ == "__main__":
    # Use SystemExit directly rather than exit(), which is only added to the
    # builtins by the site module for interactive use
    raise SystemExit(main())