blob: 671d9021526cff76529091ebfc1e1628eb43a5a9 [file] [log] [blame]
"""Generate extended reference model API with eager operator execution entrypoints"""
# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
4import copy
5import os
6import subprocess
7from xml.dom import minidom
8
9from jinja2 import Environment
10from jinja2 import FileSystemLoader
11
# Note: the main script is designed to be run from the scripts/operator_api/ directory;
# the template, XML and output paths used below are relative to it.
13
Grant Watson64285a12022-11-16 15:32:39 +000014
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names.

    The set is seeded with the generic placeholders in_t, out_t, mul_t, weight_t
    and in_out_t, and extended with the "name" attribute of every <type> element
    in tosaXml. The name TABLE_SIZE is always excluded from the result.

    Args:
        tosaXml: parsed tosa.xml document (anything with getElementsByTagName).

    Returns:
        set of type-name strings.
    """
    argTypes = {"in_t", "out_t", "mul_t", "weight_t", "in_out_t"}
    argTypes.update(
        argTypeXml.getAttribute("name")
        for argTypeXml in tosaXml.getElementsByTagName("type")
    )
    # discard() rather than remove(): do not crash on a tosa.xml revision that
    # defines no TABLE_SIZE type.
    argTypes.discard("TABLE_SIZE")
    return argTypes
25
26
def getTosaDataTypes(tosaXml):
    """
    Returns the sorted list of TOSA data type identifiers found in tosa.xml.

    For every <typesupport> element, each attribute whose name is a TOSA
    argument type (see getTosaArgTypes) contributes its non-empty value,
    prefixed with "tosa_datatype_". Duplicates are collapsed.
    """
    argTypes = getTosaArgTypes(tosaXml)
    found = {
        f"tosa_datatype_{value}"
        for support in tosaXml.getElementsByTagName("typesupport")
        for value in (support.getAttribute(name) for name in argTypes)
        if value != ""
    }
    return sorted(found)
40
41
def getSerializeOpType(tosaOpName):
    """
    Returns the Serialization library attribute type matching a TOSA operator.

    Args:
        tosaOpName: lower-case TOSA operator name, e.g. "conv2d".

    Returns:
        The attribute type name (e.g. "Conv"), or the string "None" when the
        operator has no entry in the table below.
    """
    # Named opTypeMap (not "map") so the builtin map() is not shadowed.
    opTypeMap = {
        "avg_pool2d": "Pool",
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "fully_connected": "FullyConnected",
        "matmul": "MatMul",
        "max_pool2d": "Pool",
        "transpose_conv2d": "Conv",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    return opTypeMap.get(tosaOpName, "None")
74
75
def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
    """
    Returns the Serialization library arguments for the given TOSA operator,
    each annotated with an "init" string of C++ initialization code.

    The arguments are a deep copy of allSerializeArgs[<attribute type>], so the
    shared table is never modified. Each argument's "init" is chosen as follows:
      - dType "ResizeMode": translated from the client parameter via
        translate_client_tosa_mode(client_<name>);
      - a TOSA argument of the same name exists: initialized from client_<name>
        (vector arguments use a begin/end pointer pair, sized by client_<name>_len
        when the TOSA shape is "[]");
      - otherwise: vectors get no initializer, dType "DType" becomes
        tosa::DType defaulting to DType_FP32, anything else defaults to 0.

    Returns an empty dict when the operator has no Serialization attribute type
    present in allSerializeArgs.
    """
    serOpType = getSerializeOpType(tosaOpName)
    if serOpType not in allSerializeArgs:
        return {}

    tosaArgsByName = {tosaArg["name"]: tosaArg for tosaArg in tosaArgs}
    # Serialization types that need an explicit client -> serialization translation
    translatedTypes = {"ResizeMode": "tosa_mode"}

    def initializerFor(serArg):
        # Build the C++ initializer text for one Serialization argument.
        name = serArg["name"]
        dType = serArg["dType"]
        isVector = serArg["SV"] == "V"
        if dType in translatedTypes:
            return f" = translate_client_{translatedTypes[dType]}(client_{name})"
        if name in tosaArgsByName:
            if not isVector:
                return f" = client_{name}"
            shape = tosaArgsByName[name]["shape"]
            if shape == "[]":
                return f"(&client_{name}[0], &client_{name}[0] + client_{name}_len)"
            return f"(&client_{name}[0], &client_{name}{shape})"
        # No matching function parameter: fall back to a default value.
        if isVector:
            return ""
        if dType == "DType":
            serArg["dType"] = "tosa::DType"
            return " = tosa::DType::DType_FP32"
        return " = 0"

    serOpArgs = copy.deepcopy(allSerializeArgs[serOpType])
    for serArg in serOpArgs:
        serArg["init"] = initializerFor(serArg)
    return serOpArgs
117
118
def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
    """
    Reconcile an operator's TOSA arguments with its Serialization arguments, in place.

    1. A TOSA argument whose type is a TOSA argument type (see getTosaArgTypes)
       takes the "dType" of the Serialization argument with the same name.
    2. Such an argument with no same-named Serialization argument is deleted,
       together with its "<name>_len" companion argument if one exists.
    3. Each Serialization argument that has no matching TOSA argument and whose
       dType is not "tosa::DType" is inserted before the last TOSA argument.
       A vector ("V") argument is preceded by an int32_t "<name>_len" argument.
       The Serialization argument's "init" is set to read the new client parameter.

    Args:
        tosaArgs: list of dicts with keys name/type/shape/category; modified in place.
        serializeArgs: list of dicts with keys name/dType/SV; "init" is written
            for the arguments added in step 3.
        tosaXml: parsed tosa.xml document.
    """
    tosaArgTypes = getTosaArgTypes(tosaXml)
    serArgsDict = {arg["name"]: arg for arg in serializeArgs}
    tosaArgsNames = [arg["name"] for arg in tosaArgs]

    # Steps 1 and 2. Indices are collected in a set so the same index is never
    # deleted twice (an argument can be both undeterminable itself and the "_len"
    # companion of another one); a repeated delete would remove an unrelated argument.
    delTosaArgs = set()
    for tosaArg in tosaArgs:
        if tosaArg["type"] not in tosaArgTypes:
            continue
        if tosaArg["name"] in serArgsDict:
            tosaArg["type"] = serArgsDict[tosaArg["name"]]["dType"]
        else:
            delTosaArgs.add(tosaArgsNames.index(tosaArg["name"]))
            lenArgName = f"{tosaArg['name']}_len"
            if lenArgName in tosaArgsNames:
                delTosaArgs.add(tosaArgsNames.index(lenArgName))
    # Delete from the highest index down so earlier indices stay valid.
    for index in sorted(delTosaArgs, reverse=True):
        del tosaArgs[index]

    # Step 3. Membership is checked against the arguments that survived deletion.
    tosaArgNames = {arg["name"] for arg in tosaArgs}
    for serArg in serializeArgs:
        serArgName = serArg["name"]
        if serArgName in tosaArgNames or serArg["dType"] == "tosa::DType":
            continue
        if serArg["SV"] == "V":
            # Vector data types get a matching length argument in front of them.
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{serArgName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            serArg["init"] = (
                f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)"
            )
            shape = "[]"
        else:
            serArg["init"] = f" = client_{serArgName}"
            shape = ""
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": serArgName,
                "type": serArg["dType"],
                "shape": shape,
                "category": "",
            },
        )
178
179
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each operator is a dict with keys "name", "serializeAttType", "arguments",
    "serializeArgs", "inputs" and "outputs". Operators named in skippedOps are
    left out.
    """
    skippedOps = {"while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"}
    allSerializeArgs = getSerializeArgs()
    operators = []
    for opXml in tosaXml.getElementsByTagName("operator"):
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in skippedOps:
            continue
        tosaArgs = getTosaArgs(opXml)
        serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
        updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
        operators.append(
            {
                "name": opName,
                "serializeAttType": getSerializeOpType(opName),
                "arguments": tosaArgs,
                "serializeArgs": serializeArgs,
                "inputs": [a["name"] for a in tosaArgs if a["category"] == "input"],
                "outputs": [a["name"] for a in tosaArgs if a["category"] == "output"],
            }
        )
    return operators
206
207
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    Each argument is a dict with keys "name", "type", "shape" and "category":
      - a trailing "*" is stripped from the type;
      - input/output arguments whose type is a TOSA argument type become
        "tosa_tensor_t" with an empty shape;
      - bool_t, uint6_t and mode_t are mapped to bool, uint8_t and tosa_mode_t;
      - an array shape whose size is not a plain number becomes "[]" and is
        preceded by an int32_t "<name>_len" argument;
      - any shape not starting with "[" is cleared.

    Args:
        opXml: the <operator> element from tosa.xml.
        tosaXml: parsed tosa.xml document used to look up the TOSA argument
            types. Defaults to the document that owns opXml.
    """
    # Use the owning document instead of relying on a module-level "tosaXml"
    # global, which only exists when this file is run as a script.
    if tosaXml is None:
        tosaXml = opXml.ownerDocument
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    arguments = []
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # Update argument type
        if argType.endswith("*"):
            argType = argType[:-1]
        if argCategory in ["input", "output"] and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        argType = tosaTypeMap.get(argType, argType)
        if argShape.startswith("[") and not argShape[1:-1].isnumeric():
            # Arrays with unknown compile-time size get an explicit length argument
            argShape = "[]"
            arguments.append(
                {
                    "name": f"{argName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                }
            )
        elif not argShape.startswith("["):
            argShape = ""
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
252
253
def clangFormat(filename):
    """
    Format filename in place by running ``clang-format -i`` on it.

    clang-format's standard output is discarded; standard error is left attached.

    Raises:
        subprocess.CalledProcessError: if clang-format exits with a non-zero status.
        OSError: if clang-format cannot be executed (e.g. not found on PATH).
    """
    # subprocess.DEVNULL replaces manually opening (and closing) os.devnull.
    subprocess.check_call(["clang-format", "-i", filename], stdout=subprocess.DEVNULL)
258
259
def getSerializeArgs():
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are the arguments required by each Serialization library operator.

    The file is opened through a path relative to the current working directory
    (../../thirdparty/serialization_lib/include/attribute.def).

    The parser expects entries shaped like
        DEF_ATTRIBUTE(<OpName>, ...
                      <type>, <SV>, <name>,
                      ...
                      <type>, <SV>, <name>)
    i.e. the header on its own line and exactly one argument per following line.

    Returns:
        dict mapping <OpName> to a list of dicts with keys "name", "dType"
        (the <type> text) and "SV" (the <SV> text).
    """
    serializeArgs = {}
    with open("../../thirdparty/serialization_lib/include/attribute.def") as file:
        preamble = True
        inAtt = False
        opName = ""
        args = []
        for line in file:
            # Skip everything before the first line that starts, at column 0,
            # with "DEF_ATTRIBUTE(".
            if preamble and not line[: len("DEF_ATTRIBUTE(")] == "DEF_ATTRIBUTE(":
                continue
            else:
                preamble = False
            line = line.lstrip().rstrip()
            if not inAtt and "DEF_ATTRIBUTE(" in line:
                # Entry header: the operator name is the text between
                # "DEF_ATTRIBUTE(" and the first comma.
                opName = line[len("DEF_ATTRIBUTE(") : line.find(",")]
                inAtt = True
            elif inAtt:
                # Argument line "<type>, <SV>, <name>," (the last one ends with ")").
                # NOTE: raises IndexError if a line inside an entry has fewer than
                # three comma-separated fields (e.g. a blank or comment line).
                vals = line.split(",")
                argName = vals[2].lstrip().strip()
                # Drop the closing parenthesis; assumes it is the last character.
                if ")" in argName:
                    argName = argName[:-1]
                arg = {
                    "name": argName,
                    "dType": vals[0].lstrip().strip(),
                    "SV": vals[1].lstrip().strip(),
                }
                args.append(arg)
            # A ")" anywhere on the line closes the current entry.
            if ")" in line:
                serializeArgs[opName] = args
                opName = ""
                args = []
                inAtt = False
    return serializeArgs
297
298
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render template with dataTypes and operators, write the result to outfile
    as UTF-8, report it on stdout, then run clang-format on the file.

    The environment argument is not used by this function.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as handle:
        handle.write(rendered)
        print(f"Created {outfile}")

    clangFormat(outfile)
306
307
def generate(environment, dataTypes, operators):
    """
    Render the operator API header and source from their Jinja templates.

    Writes, relative to the current working directory:
      - ../../reference_model/include/operators.h from operators_h.j2
      - ../../reference_model/src/operators.cc from operators_cc.j2
    """
    outputs = (
        ("operators_h.j2", ("include", "operators.h")),
        ("operators_cc.j2", ("src", "operators.cc")),
    )
    for templateName, (subdir, filename) in outputs:
        template = environment.get_template(templateName)
        outfile = os.path.join("..", "..", "reference_model", subdir, filename)
        renderTemplate(environment, dataTypes, operators, template, outfile)
318
319
def getSerializeOpTypeMap():
    """
    Utility function for generating the map used in getSerializeOpType().

    Each Serialization attribute name from attribute.def is converted to
    snake_case. Every TOSA operator in tosa.xml whose name contains one of those
    snake_case names is mapped to the CamelCase form of the longest one it contains.

    Returns:
        dict mapping lower-case TOSA operator name -> attribute type name.
        Operators that match no attribute name are omitted.
    """
    import re

    allSerializeArgs = getSerializeArgs()
    # CamelCase -> snake_case, longest first so the most specific name is tried first.
    serArgs = sorted(
        (re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() for name in allSerializeArgs),
        key=len,
        reverse=True,
    )
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    opNames = [
        op.getElementsByTagName("name")[0].firstChild.data.lower()
        for op in tosaXml.getElementsByTagName("operator")
    ]
    opTypeMap = {}
    for opName in opNames:
        for serArg in serArgs:
            if serArg in opName:
                opTypeMap[opName] = "".join(x.title() for x in serArg.split("_"))
                # Stop at the longest match. Without the break, a shorter name such
                # as "conv" would overwrite a longer one such as "transpose_conv".
                break
    return opTypeMap
344
345
if __name__ == "__main__":
    # All paths below are relative to the current working directory.
    # tosaXml is bound at module scope on purpose: a helper in this file may look
    # up the global name tosaXml.
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    environment = Environment(loader=FileSystemLoader("templates/"))
    generate(environment, getTosaDataTypes(tosaXml), getOperators(tosaXml))