# Source blob: 743475c2d9e71abe3574ca2410529e1c649194ba
# Copyright (c) 2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
"""Calls the data generation library to create the test data."""
import ctypes as ct
import json
import sys
from pathlib import Path

import numpy as np
import schemavalidation.schemavalidation as sch


class GenerateError(Exception):
    """Exception raised for errors performing data generation.

    Raised by GenerateLibrary when the library cannot be found, the test
    config is missing or incomplete, the data type is unsupported or the
    library reports a failure.
    """


class GenerateLibrary:
    """Python interface to the C generate library.

    Simple usage to write out all input files:
        set_config(test_desc)
        write_numpy_files(test_path)

    To get data buffers (for const data):
        get_tensor_data(tensor_name)
    """

    def __init__(self, generate_lib_path):
        """Find the library and set up the interface.

        generate_lib_path - pathlib.Path of the generate shared library

        Raises GenerateError if the path is None or is not an existing file.
        """
        self.lib_path = generate_lib_path
        if self.lib_path is None or not self.lib_path.is_file():
            raise GenerateError(f"Could not find generate library - {self.lib_path}")

        self.schema_validator = sch.TestDescSchemaValidator()

        self.test_desc = None
        self.json_config = None
        # Hand ctypes a plain string as older Python versions do not accept
        # path-like library names on every platform
        self.lib = ct.cdll.LoadLibrary(str(self.lib_path))

        # Arguments as passed in _data_gen_array: JSON config, tensor name,
        # output buffer pointer, output buffer size in bytes
        self.tgd_generate_data = self.lib.tgd_generate_data
        self.tgd_generate_data.argtypes = [
            ct.c_char_p,
            ct.c_char_p,
            ct.c_void_p,
            ct.c_size_t,
        ]
        self.tgd_generate_data.restype = ct.c_bool

    def check_config(self, test_desc: dict):
        """Quick check that the config supports data generation.

        test_desc - the test desc.json contents

        Returns True when test_desc has a meta/data_gen section.
        """
        return ("meta" in test_desc) and ("data_gen" in test_desc["meta"])

    def set_config(self, test_desc: dict):
        """Set the test config in the library.

        test_desc - the test desc.json file

        Raises GenerateError if there is no meta/data_gen section.
        """
        # Clear any previous config so a failure does not leave stale state
        self.test_desc = None
        self.json_config = None

        if not self.check_config(test_desc):
            raise GenerateError("No meta/data_gen section found in desc.json")

        # Validate the config versus the schema
        self.schema_validator.validate_config(test_desc)

        self.test_desc = test_desc
        self.json_config = test_desc["meta"]["data_gen"]

    def _create_buffer(self, dtype: str, shape: tuple):
        """Helper to create a buffer of the required type.

        dtype - data type name: FP32, FP16, INT32 or SHAPE
        shape - tensor shape, an empty tuple means a single element (rank 0)

        Returns a tuple of (ctypes buffer, buffer size in bytes).
        Raises GenerateError for any other data type.
        """
        # np.prod of an empty shape is the float 1.0, which ctypes rejects as
        # an array length, so always convert to a Python int
        size = int(np.prod(shape, dtype=np.int64))

        if dtype == "FP32":
            ctype, count = ct.c_float, size
        elif dtype == "FP16":
            # No ctypes half float, so use raw bytes (2 per element)
            ctype, count = ct.c_ubyte, size * 2
        elif dtype == "INT32" or dtype == "SHAPE":
            ctype, count = ct.c_int32, size
        else:
            raise GenerateError(f"Unsupported data type {dtype}")

        # Create buffer and initialize to zero
        buffer = (ctype * count)(0)

        return buffer, ct.sizeof(buffer)

    def _convert_buffer(self, buffer, dtype: str, shape: tuple):
        """Helper to convert a buffer to a numpy array.

        buffer - ctypes buffer as created by _create_buffer
        dtype - data type name used to create the buffer
        shape - shape of the returned array
        """
        arr = np.ctypeslib.as_array(buffer)

        if dtype == "FP16":
            # Convert from bytes back to FP16
            arr = np.frombuffer(arr, np.float16)

        arr = np.reshape(arr, shape)

        return arr

    def _data_gen_array(self, json_config: dict, tensor_name: str):
        """Generate the named tensor data and return a numpy array.

        json_config - the data_gen config containing a "tensors" section
        tensor_name - name of the tensor in the "tensors" section

        Raises GenerateError if the tensor information is missing from the
        config or the library reports a failure.
        """
        try:
            tensor = json_config["tensors"][tensor_name]
            dtype = tensor["data_type"]
            shape = tuple(tensor["shape"])
        except KeyError as e:
            raise GenerateError(
                f"Missing data in json config for input {tensor_name} - {repr(e)}"
            ) from e

        buffer, size_bytes = self._create_buffer(dtype, shape)
        buffer_ptr = ct.cast(buffer, ct.c_void_p)

        # The library takes the whole config as a JSON string
        json_bytes = bytes(json.dumps(json_config), "utf8")

        result = self.tgd_generate_data(
            ct.c_char_p(json_bytes),
            ct.c_char_p(bytes(tensor_name, "utf8")),
            buffer_ptr,
            ct.c_size_t(size_bytes),
        )
        if not result:
            raise GenerateError("Data generate failed")

        arr = self._convert_buffer(buffer, dtype, shape)
        return arr

    def _data_gen_write(
        self, test_path: Path, json_config: dict, ifm_name: str, ifm_file: str
    ):
        """Generate the named tensor data and save it in numpy format.

        test_path - directory to write the file into
        json_config - the data_gen config
        ifm_name - name of the tensor to generate
        ifm_file - file name (relative to test_path) to save the data to
        """
        arr = self._data_gen_array(json_config, ifm_name)

        file_name = test_path / ifm_file
        np.save(file_name, arr)

    def write_numpy_files(self, test_path: Path):
        """Write out all the desc.json input tensors to numpy data files.

        test_path - directory to write the files into

        Raises GenerateError if no config has been set, the desc.json is
        missing the input names/files or any tensor fails to generate.
        """
        if self.test_desc is None or self.json_config is None:
            raise GenerateError("Cannot write numpy files as no config set up")

        try:
            ifm_names = self.test_desc["ifm_name"]
            ifm_files = self.test_desc["ifm_file"]
        except KeyError as e:
            raise GenerateError(f"Missing data in desc.json - {repr(e)}") from e

        # Attempt every tensor and report all the failures together
        failures = []
        for iname, ifile in zip(ifm_names, ifm_files):
            try:
                self._data_gen_write(test_path, self.json_config, iname, ifile)
            except GenerateError as e:
                failures.append(
                    f"ERROR: Failed to create data for tensor {iname} - {repr(e)}"
                )

        if failures:
            raise GenerateError("\n".join(failures))

    def get_tensor_data(self, tensor_name: str, json_config=None):
        """Get a numpy array for a named tensor in the data_gen meta data.

        tensor_name - name of the tensor to generate
        json_config - optional data_gen config to use instead of the one
                      set up by set_config

        Raises GenerateError if no config is given and none has been set.
        """
        if json_config is None:
            if self.json_config is None:
                raise GenerateError("Cannot get tensor data as no config set up")
            json_config = self.json_config
        else:
            # Validate the given config
            self.schema_validator.validate_config(
                json_config, schema_type=sch.TD_SCHEMA_DATA_GEN
            )

        return self._data_gen_array(json_config, tensor_name)


def main(argv=None):
    """Simple command line interface for the data generator.

    argv - list of command line arguments, None leaves argparse to use the
           process arguments

    Returns the exit code:
        0 - data files written
        1 - failure initializing the library or generating the data
        2 - invalid test directory/desc.json or no data generation supported
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--generate-lib-path",
        type=Path,
        help="Path to TOSA generate lib",
    )
    parser.add_argument(
        "path", type=Path, help="the path to the test directory to generate data for"
    )
    args = parser.parse_args(argv)
    test_path = args.path

    if args.generate_lib_path is None:
        # Only the default lookup needs this module, so import it here
        import conformance.model_files as cmf

        # Try to work out ref model directory and find the generate library
        # but this default only works for the python developer environment
        # i.e. when using the scripts/py-dev-env.* scripts
        # otherwise use the command line option --generate-lib-path to specify path
        ref_model_dir = Path(__file__).absolute().parents[2]
        args.generate_lib_path = cmf.find_tosa_file(
            cmf.TosaFileType.GENERATE_LIBRARY, ref_model_dir, False
        )

    if not test_path.is_dir():
        print(f"ERROR: Invalid directory - {test_path}")
        return 2

    test_desc_path = test_path / "desc.json"

    if not test_desc_path.is_file():
        print(f"ERROR: No test description found: {test_desc_path}")
        return 2

    # Load the JSON desc.json
    try:
        with test_desc_path.open("r") as fd:
            test_desc = json.load(fd)
    except Exception as e:
        print(f"ERROR: Loading {test_desc_path} - {repr(e)}")
        return 2

    try:
        dgl = GenerateLibrary(args.generate_lib_path)
        if not dgl.check_config(test_desc):
            print(f"WARNING: No data generation supported for {test_path}")
            return 2

        dgl.set_config(test_desc)
    except GenerateError as e:
        print(f"ERROR: Initializing generate library - {repr(e)}")
        return 1

    try:
        dgl.write_numpy_files(test_path)
    except GenerateError as e:
        print(f"ERROR: Writing out data files to {test_path}\n{repr(e)}")
        return 1

    # Make the success exit code explicit rather than an implicit None
    return 0


if __name__ == "__main__":
    # sys.exit rather than the site builtin exit(), which is intended for
    # interactive use and is not available when running without site
    sys.exit(main())