blob: 1f89f74cfe93891ebf3c470751a07ce70d3abb81 [file] [log] [blame]
"""Generate extended reference model API with eager operator execution entrypoints"""
# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
4import copy
5import os
6import subprocess
7from xml.dom import minidom
8
9from jinja2 import Environment
10from jinja2 import FileSystemLoader
11
12
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names from tosa.xml.

    The result starts from a fixed group of generic type names and adds the
    "name" attribute of every <type> element found in the document. The
    "TABLE_SIZE" entry is excluded from the result.

    tosaXml: parsed tosa.xml document (must provide getElementsByTagName).
    """
    argTypes = {"in_t", "out_t", "mul_t", "weight_t", "in_out_t"}
    argTypes.update(
        typeXml.getAttribute("name")
        for typeXml in tosaXml.getElementsByTagName("type")
    )
    # discard() rather than remove(): a specification that has no TABLE_SIZE
    # entry must not abort the whole generation run with a KeyError.
    argTypes.discard("TABLE_SIZE")
    return argTypes
23
24
def getTosaDataTypes(tosaXml):
    """
    Returns a sorted list of the TOSA data types from tosa.xml.

    Every <typesupport> element is inspected for each known argument type
    name; each non-empty attribute value is reported once, prefixed with
    "tosa_datatype_".
    """
    argTypes = getTosaArgTypes(tosaXml)
    collected = set()
    for supportXml in tosaXml.getElementsByTagName("typesupport"):
        values = (supportXml.getAttribute(argType) for argType in argTypes)
        # getAttribute() yields "" for attributes that are not present
        collected.update(f"tosa_datatype_{value}" for value in values if value != "")
    return sorted(collected)
38
39
def getSerializeOpType(tosaOpName):
    """
    Returns the Serialization library operator that matches the TOSA operator specified.

    Operators without an entry in the table yield the string "None"
    (a string, not the None object).
    """
    serializeOpTypes = {
        "avg_pool2d": "Pool",
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "fully_connected": "FullyConnected",
        "matmul": "MatMul",
        "max_pool2d": "Pool",
        "transpose_conv2d": "Conv",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    return serializeOpTypes.get(tosaOpName, "None")
72
73
def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
    """
    Returns the Serialization library arguments for the TOSA operator specified.

    The entries are deep copies of the ones stored in allSerializeArgs, each
    extended with an "init" key holding the code used to initialize it. When a
    TOSA argument of the same name exists its value is used, otherwise a
    default value (e.g. 0) is used. An empty dict is returned when the
    operator has no Serialization arguments.
    """
    serOpType = getSerializeOpType(tosaOpName)
    if serOpType not in allSerializeArgs:
        return {}

    serOpArgs = copy.deepcopy(allSerializeArgs[serOpType])
    tosaArgsByName = {tosaArg["name"]: tosaArg for tosaArg in tosaArgs}
    translatedTypes = {"ResizeMode": "tosa_mode"}
    for serArg in serOpArgs:
        name = serArg["name"]
        dType = serArg["dType"]
        if dType in translatedTypes:
            # TOSA data types that need translating to a Serialization data type
            init = f" = translate_client_{translatedTypes[dType]}(client_{name})"
        elif name in tosaArgsByName:
            # A function parameter of the same name exists: initialize from it
            if serArg["SV"] != "V":
                init = f" = client_{name}"
            else:
                shape = tosaArgsByName[name]["shape"]
                if shape == "[]":
                    rangeEnd = f"&client_{name}[0] + client_{name}_len"
                else:
                    rangeEnd = f"&client_{name}{shape}"
                init = f"(&client_{name}[0], {rangeEnd})"
        elif serArg["SV"] == "V":
            # No matching function parameter: vectors get no initializer
            init = ""
        elif dType == "DType":
            # No matching function parameter: qualify the type and default it
            serArg["dType"] = "tosa::DType"
            init = " = tosa::DType::DType_FP32"
        else:
            init = " = 0"
        serArg["init"] = init
    return serOpArgs
115
116
def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
    """
    Reconcile the TOSA argument list with the Serialization arguments, in place.

    - TOSA arguments with a TOSA argument type take the data type of the
      Serialization argument of the same name.
    - TOSA arguments with a TOSA argument type but no Serialization
      counterpart are deleted, together with their "<name>_len" companion.
    - Serialization arguments with no TOSA counterpart are inserted just
      before the last TOSA argument (preceded by a length argument for
      vectors) and have their "init" entry set accordingly.
    """
    knownArgTypes = getTosaArgTypes(tosaXml)
    serArgsByName = {serArg["name"]: serArg for serArg in serializeArgs}
    originalNames = [tosaArg["name"] for tosaArg in tosaArgs]

    # First pass: retype what can be matched, remember positions of the rest
    positionsToDrop = []
    for tosaArg in tosaArgs:
        if tosaArg["type"] not in knownArgTypes:
            continue
        name = tosaArg["name"]
        if name in serArgsByName:
            tosaArg["type"] = serArgsByName[name]["dType"]
            continue
        positionsToDrop.append(originalNames.index(name))
        lengthName = f"{name}_len"
        if lengthName in originalNames:
            positionsToDrop.append(originalNames.index(lengthName))

    # Delete from the back so that the remaining positions stay valid
    for position in sorted(positionsToDrop, key=int, reverse=True):
        del tosaArgs[position]

    # Second pass: add Serialization arguments that have no TOSA counterpart
    remainingNames = [tosaArg["name"] for tosaArg in tosaArgs]
    for serArg in serializeArgs:
        serName = serArg["name"]
        if serName in remainingNames or serArg["dType"] == "tosa::DType":
            continue
        if serArg["SV"] == "V":
            # Vectors are passed as pointer + length, so add the length first
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{serName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            newShape = "[]"
            serArg["init"] = (
                f"(&client_{serName}[0], &client_{serName}[0] + client_{serName}_len)"
            )
        else:
            newShape = ""
            serArg["init"] = f" = client_{serName}"
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": serName,
                "type": serArg["dType"],
                "shape": newShape,
                "category": "",
            },
        )
176
177
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each operator is a dict with the keys "name", "serializeAttType",
    "arguments", "serializeArgs", "inputs" and "outputs". Operators named in
    the skip list below are left out.
    """
    skippedOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
    operators = []
    opsXml = tosaXml.getElementsByTagName("operator")
    allSerializeArgs = getSerializeArgs()
    for opXml in opsXml:
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in skippedOps:
            continue
        serializeAttType = getSerializeOpType(opName)
        tosaArgs = getTosaArgs(opXml)
        serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
        # Mutates tosaArgs (and the "init" of some serializeArgs) in place
        updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
        operators.append(
            {
                "name": opName,
                "serializeAttType": serializeAttType,
                "arguments": tosaArgs,
                "serializeArgs": serializeArgs,
                "inputs": [a["name"] for a in tosaArgs if a["category"] == "input"],
                "outputs": [a["name"] for a in tosaArgs if a["category"] == "output"],
            }
        )
    return operators
204
205
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    opXml: the <operator> element whose <argument> children are read.
    tosaXml: the parsed tosa.xml document used to look up the TOSA argument
        types. Optional; defaults to the document that owns opXml, so the
        function no longer depends on a module-level `tosaXml` global.

    Returns a list of dicts with the keys "name", "type", "shape" and
    "category". An array argument whose size is not a numeric literal is
    preceded by an extra int32_t "<name>_len" argument.
    """
    if tosaXml is None:
        tosaXml = opXml.ownerDocument
    arguments = []
    argsXml = opXml.getElementsByTagName("argument")
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    for xmlArg in argsXml:
        argName = xmlArg.getAttribute("name").lower()
        argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # Update argument type: drop a trailing "*" marker
        if argType.endswith("*"):
            argType = argType[:-1]
        # Inputs/outputs of a TOSA argument type are passed as tensors
        if argCategory in ["input", "output"] and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        if argType in tosaTypeMap:
            argType = tosaTypeMap[argType]
        # Add a length argument for arrays with unknown compile-time size
        if argShape != "" and argShape[0] == "[" and not argShape[1:-1].isnumeric():
            argShape = "[]"
            arguments.append(
                {
                    "name": f"{argName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                }
            )
        elif argShape == "" or not argShape[0] == "[":
            # Shapes that are not written as "[...]" are dropped
            argShape = ""
        # Append argument
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
250
251
def clangFormat(filename):
    """
    Run clang-format in place ("-i") on the file specified.

    The tool's standard output is discarded. A non-zero exit status raises
    subprocess.CalledProcessError.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename], stdout=subprocess.DEVNULL
    )
256
257
def getSerializeArgs(
    attributeDefPath="../../thirdparty/serialization_lib/include/attribute.def",
):
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are the arguments required by each Serialization library operator.

    attributeDefPath: location of attribute.def. The default is the path
        relative to the current working directory that was previously
        hard-coded, so existing callers are unaffected.

    Each argument is a dict with the keys "name", "dType" and "SV", read from
    the comma separated "<dType>, <SV>, <name>" lines that follow a
    "DEF_ATTRIBUTE(<operator>, ..." line, up to the line holding the closing
    ")".
    """
    marker = "DEF_ATTRIBUTE("
    serializeArgs = {}
    with open(attributeDefPath) as file:
        preamble = True
        inAtt = False
        opName = ""
        args = []
        for line in file:
            # Skip everything before the first definition (it must start at
            # column 0 to end the preamble)
            if preamble:
                if not line.startswith(marker):
                    continue
                preamble = False
            line = line.strip()
            if not inAtt:
                if marker in line:
                    # The operator name sits between the marker and the first comma
                    opName = line[len(marker) : line.find(",")]
                    inAtt = True
                continue
            vals = line.split(",")
            argName = vals[2].strip()
            if ")" in argName:
                # Last argument: drop the closing parenthesis
                argName = argName[:-1]
            args.append(
                {
                    "name": argName,
                    "dType": vals[0].strip(),
                    "SV": vals[1].strip(),
                }
            )
            if ")" in line:
                # End of this definition: store it and reset the state
                serializeArgs[opName] = args
                opName = ""
                args = []
                inAtt = False
    return serializeArgs
295
296
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render the template with the data types and operators given, write the
    result to outfile (UTF-8) and run clang-format on the file written.

    The environment argument is accepted but not used here.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as target:
        target.write(rendered)
        print(f"Created {outfile}")

    clangFormat(outfile)
304
305
def generate(environment, dataTypes, operators):
    """
    Generate include/operators.h and src/operators.cc of the reference model
    from their templates. Output paths are relative to the current working
    directory.
    """
    # (template name, location below ../../reference_model)
    targets = (
        ("operators_h.j2", ("include", "operators.h")),
        ("operators_cc.j2", ("src", "operators.cc")),
    )
    for templateName, location in targets:
        template = environment.get_template(templateName)
        outfile = os.path.join("..", "..", "reference_model", *location)
        renderTemplate(environment, dataTypes, operators, template, outfile)
316
317
def getSerializeOpTypeMap():
    """
    Utility function for generating the map used in getSerializeOpType()

    Serialization operator names are converted to snake_case and matched as
    substrings against the TOSA operator names. Candidates are tried from the
    longest to the shortest and every match overwrites the previous one, so
    the last matching candidate is the one recorded.
    """
    import re

    allSerializeArgs = getSerializeArgs()
    candidates = sorted(
        (re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() for name in allSerializeArgs),
        key=len,
        reverse=True,
    )
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    opNames = [
        opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        for opXml in tosaXml.getElementsByTagName("operator")
    ]
    opTypeMap = {}
    for opName in opNames:
        for candidate in candidates:
            if candidate not in opName:
                continue
            # snake_case back to CamelCase
            opTypeMap[opName] = "".join(
                part.title() for part in candidate.split("_")
            )
    return opTypeMap
342
343
if __name__ == "__main__":
    # Script entry point. The template directory and the specification path
    # are relative to the current working directory.
    # NOTE: keep tosaXml bound as a module-level name here rather than moving
    # it into a function - code elsewhere in this file may look it up as a
    # global.
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    environment = Environment(loader=FileSystemLoader("templates/"))
    generate(environment, getTosaDataTypes(tosaXml), getOperators(tosaXml))
349 generate(environment, dataTypes, operators)