blob: 0d590847989c27f155bc7100a9d047df4192c910 [file] [log] [blame]
Jeremy Johnson65ba8092023-10-09 16:31:13 +01001# Copyright (c) 2023, ARM Limited.
2# SPDX-License-Identifier: Apache-2.0
3"""Calls the data generation library to create the test data."""
4import ctypes as ct
5import json
6from pathlib import Path
7
8import numpy as np
Jeremy Johnsond1a08ce2023-10-18 17:22:21 +01009import schemavalidation.schemavalidation as sch
Jeremy Johnson65ba8092023-10-09 16:31:13 +010010
11
class GenerateError(Exception):
    """Signal a failure while generating test data.

    Raised by GenerateLibrary for a missing library or configuration, an
    unsupported tensor data type, or a failed call into the C library.
    """
14
15
class GenerateLibrary:
    """Python interface to the C generate library.

    Simple usage to write out all input files:
        set_config(test_desc)
        write_numpy_files(test_path)

    To get data buffers (for const data):
        get_tensor_data(tensor_name)
    """

    def __init__(self, generate_lib_path):
        """Find the library and set up the interface.

        generate_lib_path - path to the generate shared library (Path or str)

        Raises GenerateError if no path is given or the file does not exist.
        """
        if generate_lib_path is None or not Path(generate_lib_path).is_file():
            raise GenerateError(
                f"Could not find generate library - {generate_lib_path}"
            )
        self.lib_path = Path(generate_lib_path)

        self.schema_validator = sch.TestDescSchemaValidator()

        self.test_desc = None
        self.json_config = None
        # Pass a str: ctypes only documents path-like library names from
        # Python 3.12 onwards
        self.lib = ct.cdll.LoadLibrary(str(self.lib_path))

        # C entry point as used below:
        #   (json config, tensor name, output buffer, buffer size in bytes)
        #   -> bool success
        self.tgd_generate_data = self.lib.tgd_generate_data
        self.tgd_generate_data.argtypes = [
            ct.c_char_p,
            ct.c_char_p,
            ct.c_void_p,
            ct.c_size_t,
        ]
        self.tgd_generate_data.restype = ct.c_bool

    def check_config(self, test_desc: dict):
        """Quick check that the config supports data generation.

        Returns True when test_desc contains a meta/data_gen section.
        """
        return ("meta" in test_desc) and ("data_gen" in test_desc["meta"])

    def set_config(self, test_desc: dict):
        """Set the test config in the library.

        test_desc - the test desc.json file contents

        Raises GenerateError if there is no meta/data_gen section; schema
        validation failures are reported by the schema validator.
        """
        # Clear any previous config so a failure leaves nothing half set up
        self.test_desc = None
        self.json_config = None

        if not self.check_config(test_desc):
            raise GenerateError("No meta/data_gen section found in desc.json")

        # Validate the config versus the schema
        self.schema_validator.validate_config(test_desc)

        self.test_desc = test_desc
        self.json_config = test_desc["meta"]["data_gen"]

    def _create_buffer(self, dtype: str, shape: tuple):
        """Helper to create a zero-initialized buffer of the required type.

        Returns (buffer, size_bytes).
        Raises GenerateError for unsupported data types.
        """
        size = 1
        for dim in shape:
            size *= dim

        if dtype == "FP32":
            # ctypes arrays are zero-initialized; an explicit initializer
            # would fail for empty (zero element) shapes
            buffer = (ct.c_float * size)()
        else:
            raise GenerateError(f"Unsupported data type {dtype}")

        return buffer, ct.sizeof(buffer)

    def _data_gen_array(self, json_config: dict, tensor_name: str):
        """Generate the named tensor data and return a numpy array.

        json_config - the meta/data_gen section of a desc.json
        tensor_name - name of a tensor in json_config["tensors"]

        Raises GenerateError on missing config data, an unsupported data
        type, or if the library fails to generate the data.
        """
        try:
            tensor = json_config["tensors"][tensor_name]
            dtype = tensor["data_type"]
            shape = tuple(tensor["shape"])
        except KeyError as e:
            raise GenerateError(
                f"Missing data in json config for input {tensor_name} - {repr(e)}"
            ) from e

        buffer, size_bytes = self._create_buffer(dtype, shape)
        buffer_ptr = ct.cast(buffer, ct.c_void_p)

        json_bytes = bytes(json.dumps(json_config), "utf8")

        result = self.tgd_generate_data(
            ct.c_char_p(json_bytes),
            ct.c_char_p(bytes(tensor_name, "utf8")),
            buffer_ptr,
            ct.c_size_t(size_bytes),
        )
        if not result:
            raise GenerateError("Data generate failed")

        # The library filled the ctypes buffer in place; wrap it as numpy
        arr = np.ctypeslib.as_array(buffer)
        return np.reshape(arr, shape)

    def _data_gen_write(
        self, test_path: Path, json_config: dict, ifm_name: str, ifm_file: str
    ):
        """Generate the named tensor data and save it in numpy format.

        The data is written with np.save to test_path / ifm_file.
        """
        arr = self._data_gen_array(json_config, ifm_name)

        file_name = test_path / ifm_file
        np.save(file_name, arr)

    def write_numpy_files(self, test_path: Path):
        """Write out all the desc.json input tensors to numpy data files.

        test_path - directory to write the files into; each file is named
        by the matching desc.json ifm_file entry

        Raises GenerateError if no config is set, the desc.json input lists
        are missing or mismatched, or any tensor fails to generate (all
        tensor failures are collected and reported together).
        """
        if self.test_desc is None or self.json_config is None:
            raise GenerateError("Cannot write numpy files as no config set up")

        try:
            ifm_names = self.test_desc["ifm_name"]
            ifm_files = self.test_desc["ifm_file"]
        except KeyError as e:
            raise GenerateError(f"Missing data in desc.json - {repr(e)}") from e

        # zip() would silently skip any unmatched entries
        if len(ifm_names) != len(ifm_files):
            raise GenerateError(
                f"Mismatched ifm_name ({len(ifm_names)}) and"
                f" ifm_file ({len(ifm_files)}) entries in desc.json"
            )

        failures = []
        for iname, ifile in zip(ifm_names, ifm_files):
            try:
                self._data_gen_write(test_path, self.json_config, iname, ifile)
            except GenerateError as e:
                failures.append(
                    f"ERROR: Failed to create data for tensor {iname} - {repr(e)}"
                )

        if failures:
            raise GenerateError("\n".join(failures))

    def get_tensor_data(self, tensor_name: str, json_config=None):
        """Get a numpy array for a named tensor in the data_gen meta data.

        tensor_name - name of the tensor to generate
        json_config - optional data_gen config; when None the config from
        set_config is used, otherwise it is validated against the data_gen
        schema first

        Raises GenerateError if no config is available or generation fails.
        """
        if json_config is None:
            if self.json_config is None:
                raise GenerateError("Cannot get tensor data as no config set up")
            json_config = self.json_config
        else:
            # Validate the given config
            self.schema_validator.validate_config(
                json_config, schema_type=sch.TD_SCHEMA_DATA_GEN
            )

        return self._data_gen_array(json_config, tensor_name)
159
Jeremy Johnson65ba8092023-10-09 16:31:13 +0100160
def main(argv=None):
    """Simple command line interface for the data generator.

    argv - optional list of command line arguments (None uses sys.argv)

    Returns 0 on success, 1 if a GenerateError is raised while setting up
    the library or writing the data files, and 2 for an invalid test
    directory, a missing or unreadable desc.json, or a test without data
    generation support.
    """
    import argparse
    import conformance.model_files as cmf

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--generate-lib-path",
        type=Path,
        help="Path to TOSA generate lib",
    )
    parser.add_argument(
        "path", type=Path, help="the path to the test directory to generate data for"
    )
    args = parser.parse_args(argv)
    test_path = args.path

    if args.generate_lib_path is None:
        # Try to work out ref model directory and find the generate library
        # but this default only works for the python developer environment
        # i.e. when using the scripts/py-dev-env.* scripts
        # otherwise use the command line option --generate-lib-path to specify path
        ref_model_dir = Path(__file__).absolute().parents[2]
        args.generate_lib_path = cmf.find_tosa_file(
            cmf.TosaFileType.GENERATE_LIBRARY, ref_model_dir, False
        )

    if not test_path.is_dir():
        print(f"ERROR: Invalid directory - {test_path}")
        return 2

    test_desc_path = test_path / "desc.json"

    if not test_desc_path.is_file():
        print(f"ERROR: No test description found: {test_desc_path}")
        return 2

    # Load the JSON desc.json (JSON is UTF-8; don't depend on the locale)
    try:
        with test_desc_path.open("r", encoding="utf-8") as fd:
            test_desc = json.load(fd)
    except Exception as e:
        print(f"ERROR: Loading {test_desc_path} - {repr(e)}")
        return 2

    try:
        dgl = GenerateLibrary(args.generate_lib_path)
        if not dgl.check_config(test_desc):
            print(f"WARNING: No data generation supported for {test_path}")
            return 2

        dgl.set_config(test_desc)
    except GenerateError as e:
        print(f"ERROR: Initializing generate library - {repr(e)}")
        return 1

    try:
        dgl.write_numpy_files(test_path)
    except GenerateError as e:
        print(f"ERROR: Writing out data files to {test_path}\n{repr(e)}")
        return 1

    return 0
222
223
if __name__ == "__main__":
    import sys

    # Use sys.exit rather than the site-provided exit() builtin, which is
    # intended for the interactive shell and is absent under `python -S`
    sys.exit(main())