blob: 499eadb309d68ffcd78e2f87c18e9e9461406961 [file] [log] [blame]
Grant Watson64285a12022-11-16 15:32:39 +00001"""Generate extended reference model API with eager operator execution entrypoints"""
2# Copyright (c) 2021-2022, ARM Limited.
3# SPDX-License-Identifier: Apache-2.0
4import copy
5import os
6import subprocess
7from xml.dom import minidom
8
9from jinja2 import Environment
10from jinja2 import FileSystemLoader
11
# Note: run this script from the scripts/operator_api/ directory; all input and output paths are relative to it.
13
Grant Watson64285a12022-11-16 15:32:39 +000014
def getTosaArgTypes(tosaXml):
    """
    Return the set of TOSA argument type names.

    The set starts with a fixed group of generic type names used in tosa.xml
    argument declarations (tensor_t, in_t, out_t, ...). It is then extended with
    the "name" attribute of every <type> element found in tosaXml.
    "TABLE_SIZE" is always excluded.

    Args:
        tosaXml: parsed tosa.xml document (an xml.dom.minidom node).

    Returns:
        set of str.
    """
    argTypes = {
        "tensor_t",
        "in_t",
        "out_t",
        "mul_t",
        "weight_t",
        "in_out_t",
        "tensor_list_t",
    }
    argTypes.update(
        typeXml.getAttribute("name")
        for typeXml in tosaXml.getElementsByTagName("type")
    )
    # TABLE_SIZE is listed as a <type> but is excluded here (presumably because it is
    # not an argument type). discard() avoids a KeyError when the specification does
    # not define it.
    argTypes.discard("TABLE_SIZE")
    return argTypes
33
34
def getTosaDataTypes(tosaXml):
    """
    Collect the data type names referenced by <typesupport> elements in tosa.xml.

    For every <typesupport> element, each attribute whose name is a TOSA argument
    type (as returned by getTosaArgTypes) names a data type. Every distinct,
    non-empty value is prefixed with "tosa_datatype_".

    Returns:
        sorted list of the prefixed data type names.
    """
    argTypes = getTosaArgTypes(tosaXml)
    found = {
        f"tosa_datatype_{value}"
        for supportXml in tosaXml.getElementsByTagName("typesupport")
        for value in (supportXml.getAttribute(argType) for argType in argTypes)
        if value != ""
    }
    return sorted(found)
48
49
def getSerializeOpType(tosaOpName):
    """
    Map a lower-case TOSA operator name to its Serialization library attribute type.

    Returns the attribute type name (e.g. "Conv" for "conv2d"). Returns the string
    "None", not the None object, when the operator has no entry in the table.
    """
    attributeTypes = {
        # Pooling
        "avg_pool2d": "Pool",
        "max_pool2d": "Pool",
        # Convolution family
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "transpose_conv2d": "Conv",
        # Everything else maps one-to-one onto a CamelCase attribute name
        "fully_connected": "FullyConnected",
        "matmul": "MatMul",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    return attributeTypes.get(tosaOpName, "None")
82
83
def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
    """
    Return the Serialization library arguments required by a TOSA operator.

    Each returned argument is a deep copy of the matching entry in allSerializeArgs,
    with an added "init" key holding a C++ initializer fragment:
      * argument types listed in serTosaTypeMap (currently ResizeMode) are converted
        with a translate_client_<...>() call on the matching client_ parameter;
      * arguments that share a name with a TOSA argument are initialized from that
        client_ parameter: vectors through an iterator-pair constructor, scalars by
        assignment;
      * other vector arguments get an empty initializer; other DType scalars are
        retyped to tosa::DType and default to DType_FP32; other scalars default to 0.

    Args:
        tosaOpName: lower-case TOSA operator name.
        allSerializeArgs: dict from getSerializeArgs(), keyed by attribute type.
        tosaArgs: list of argument dicts (with "name" and "shape" keys) for the operator.

    Returns:
        list of argument dicts. The list is empty when the operator has no
        Serialization attribute.
    """
    serOpType = getSerializeOpType(tosaOpName)
    if serOpType not in allSerializeArgs:
        # Return an empty list, not a dict, so callers always get the same type.
        return []
    serOpArgs = copy.deepcopy(allSerializeArgs[serOpType])
    tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
    serTosaTypeMap = {"ResizeMode": "tosa_mode"}
    for arg in serOpArgs:
        argName = arg["name"]
        if arg["dType"] in serTosaTypeMap:
            # Translate TOSA data types to Serialization data types for initialization
            init = f" = translate_client_{serTosaTypeMap[arg['dType']]}(client_{argName})"
        elif argName in tosaArgsDict:
            # Initialize from the matching function parameter
            if arg["SV"] == "V":
                shape = tosaArgsDict[argName]["shape"]
                if shape == "[]":
                    # Size unknown at compile time: the length is passed as client_<name>_len
                    init = f"(&client_{argName}[0], &client_{argName}[0] + client_{argName}_len)"
                else:
                    init = f"(&client_{argName}[0], &client_{argName}{shape})"
            else:
                init = f" = client_{argName}"
        elif arg["SV"] == "V":
            # No matching function parameter: vectors start empty
            init = ""
        elif arg["dType"] == "DType":
            arg["dType"] = "tosa::DType"
            init = " = tosa::DType::DType_FP32"
        else:
            init = " = 0"
        arg["init"] = init
    return serOpArgs
125
126
def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
    """
    Reconcile an operator's TOSA arguments with its Serialization arguments, in place.

    1. A TOSA argument whose type is a generic TOSA argument type (see
       getTosaArgTypes) takes the dType of the Serialization argument with the same
       name. If no such Serialization argument exists, the TOSA argument is removed,
       together with its "<name>_len" companion argument if there is one.
    2. Each Serialization argument whose name is not among the remaining TOSA
       arguments, and whose dType is not "tosa::DType", is inserted just before the
       current last TOSA argument. Vector arguments are preceded by an int32_t
       "<name>_len" argument. The Serialization argument's "init" entry is rewritten
       to read from the matching client_ parameter.

    Returns None. tosaArgs and the dicts in serializeArgs are modified in place.
    """
    genericTypes = getTosaArgTypes(tosaXml)
    serArgsByName = {serArg["name"]: serArg for serArg in serializeArgs}
    originalNames = [tosaArg["name"] for tosaArg in tosaArgs]

    # Pass 1: retype arguments, or record the indices of those to drop.
    doomed = []
    for tosaArg in tosaArgs:
        if tosaArg["type"] not in genericTypes:
            continue
        if tosaArg["name"] in serArgsByName:
            tosaArg["type"] = serArgsByName[tosaArg["name"]]["dType"]
            continue
        doomed.append(originalNames.index(tosaArg["name"]))
        lenName = f"{tosaArg['name']}_len"
        if lenName in originalNames:
            doomed.append(originalNames.index(lenName))
    # Delete from the highest index down so earlier indices stay valid.
    for position in sorted(doomed, reverse=True):
        del tosaArgs[position]

    # Pass 2: add Serialization arguments that have no TOSA counterpart.
    remainingNames = [tosaArg["name"] for tosaArg in tosaArgs]
    for serArg in serializeArgs:
        serArgName = serArg["name"]
        if serArgName in remainingNames or serArg["dType"] == "tosa::DType":
            continue
        isVector = serArg["SV"] == "V"
        if isVector:
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{serArgName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            serArg["init"] = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)"
        else:
            serArg["init"] = f" = client_{serArgName}"
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": serArgName,
                "type": serArg["dType"],
                "shape": "[]" if isVector else "",
                "category": "",
            },
        )
186
187
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each operator is a dict with the keys:
        name: lower-case operator name.
        serializeAttType: Serialization attribute type. This is "None" when there is
            none, or "Axis" for otherwise unmapped operators that take an "axis"
            argument.
        arguments: function arguments after reconciliation with the Serialization
            arguments (see updateTosaArgs).
        serializeArgs: Serialization arguments with their C++ "init" fragments.
        inputs / outputs: names of the arguments whose category is "input" /
            "output".

    Operators named in ignoreOps are skipped.
    """
    # Operators that are left out of the returned list.
    ignoreOps = {"while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d", "erf"}
    allSerializeArgs = getSerializeArgs()
    operators = []
    for opXml in tosaXml.getElementsByTagName("operator"):
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in ignoreOps:
            continue
        operator = {"name": opName, "serializeAttType": getSerializeOpType(opName)}
        tosaArgs = getTosaArgs(opXml)
        serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
        # Operators with no dedicated attribute but with an "axis" argument use Axis.
        if operator["serializeAttType"] == "None" and any(
            arg["name"] == "axis" for arg in tosaArgs
        ):
            operator["serializeAttType"] = "Axis"
            serializeArgs = [
                {
                    "name": "axis",
                    "dType": "int32_t",
                    "SV": "S",
                    # Same " = client_<name>" form as other scalar initializers.
                    "init": " = client_axis",
                }
            ]
        updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
        operator["arguments"] = tosaArgs
        operator["serializeArgs"] = serializeArgs
        operator["inputs"] = [
            arg["name"] for arg in tosaArgs if arg["category"] == "input"
        ]
        operator["outputs"] = [
            arg["name"] for arg in tosaArgs if arg["category"] == "output"
        ]
        operators.append(operator)
    return operators
226
227
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    Args:
        opXml: the <operator> element from tosa.xml.
        tosaXml: the parsed tosa.xml document used to look up TOSA argument types.
            Defaults to the document that owns opXml. Previously the function read
            a module-level ``tosaXml`` global that only exists when this file runs
            as a script.

    Returns:
        list of dicts with the keys "name", "type", "shape" and "category".
        - Input/output tensors are typed "tosa_tensor_t".
        - Array arguments whose size is not a numeric literal get shape "[]" and are
          preceded by an extra int32_t "<name>_len" argument.
        - Non-array arguments get shape "".
    """
    if tosaXml is None:
        tosaXml = opXml.ownerDocument
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    arguments = []
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t":
            argType = "tosa_mode_t"
        else:
            argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # Strip a trailing pointer marker from the type name
        if argType.endswith("*"):
            argType = argType[:-1]
        # Input/output tensors are all passed as tosa_tensor_t, with no array shape
        if argCategory in ("input", "output") and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        argType = tosaTypeMap.get(argType, argType)
        if argShape.startswith("["):
            if not argShape[1:-1].isnumeric():
                # Size unknown at compile time: pass the length as a separate argument
                argShape = "[]"
                arguments.append(
                    {
                        "name": f"{argName}_len",
                        "type": "int32_t",
                        "shape": "",
                        "category": "",
                    }
                )
        else:
            argShape = ""
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
275
276
def clangFormat(filename):
    """
    Reformat ``filename`` in place with clang-format, discarding its stdout.

    Raises subprocess.CalledProcessError if clang-format exits with a non-zero status.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename],
        stdout=subprocess.DEVNULL,
    )
281
282
def _parseAttributeStatement(body):
    """
    Split the text between "DEF_ATTRIBUTE(" and ")" into (opName, args).

    ``body`` is a comma-separated list: the attribute name, the argument count, and
    then one (type, S/V, name) triple per argument. Empty fields, such as those left
    by a trailing comma, are ignored. The argument count is not checked.
    """
    fields = [field.strip() for field in body.split(",")]
    fields = [field for field in fields if field]
    opName = fields[0] if fields else ""
    args = [
        {"name": fields[i + 2], "dType": fields[i], "SV": fields[i + 1]}
        for i in range(2, len(fields) - 2, 3)
    ]
    return opName, args


def getSerializeArgs(
    defFile="../../thirdparty/serialization_lib/include/attribute.def",
):
    """
    Parse attribute.def and return the arguments of every Serialization attribute.

    Each entry in the file has the form
        DEF_ATTRIBUTE(OpName, NUM_ARGS, TYPE0, S|V, name0, TYPE1, S|V, name1, ...)
    and may be on one line or spread over several. The whole statement, up to its
    closing ")", is collected before it is split. Arguments on the opening line, or
    several arguments on one line, are therefore handled.

    Args:
        defFile: path to attribute.def. The default is relative to the
            scripts/operator_api/ directory.

    Returns:
        dict mapping each Serialization library operator name to a list of argument
        dicts with the keys "name", "dType" and "SV".
    """
    serializeArgs = {}
    statement = None  # text collected so far for the current DEF_ATTRIBUTE(...)
    with open(defFile) as file:
        preamble = True
        for rawLine in file:
            # Skip the header up to the first DEF_ATTRIBUTE( in column 0. Indented
            # occurrences, e.g. in the syntax description comment, do not end it.
            if preamble:
                if not rawLine.startswith("DEF_ATTRIBUTE("):
                    continue
                preamble = False
            line = rawLine.strip()
            if statement is None:
                start = line.find("DEF_ATTRIBUTE(")
                if start < 0:
                    continue
                statement = line[start + len("DEF_ATTRIBUTE(") :]
            else:
                statement = f"{statement} {line}"
            end = statement.find(")")
            if end >= 0:
                opName, args = _parseAttributeStatement(statement[:end])
                serializeArgs[opName] = args
                statement = None
    return serializeArgs
320
321
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render ``template`` with the given data types and operators and write it out.

    The rendered text is written to ``outfile`` as UTF-8, a confirmation line is
    printed, and clang-format is then run on the file. ``environment`` is accepted
    but not used by this function.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as handle:
        handle.write(rendered)
    print(f"Created {outfile}")
    clangFormat(outfile)
329
330
def generate(environment, dataTypes, operators):
    """
    Render the operator API header and source files.

    Writes ../../reference_model/include/operators.h from operators_h.j2 and
    ../../reference_model/src/operators.cc from operators_cc.j2. Both paths are
    relative to the current working directory.
    """
    outputs = (
        ("operators_h.j2", "include", "operators.h"),
        ("operators_cc.j2", "src", "operators.cc"),
    )
    for templateName, subdir, fileName in outputs:
        renderTemplate(
            environment,
            dataTypes,
            operators,
            environment.get_template(templateName),
            os.path.join("..", "..", "reference_model", subdir, fileName),
        )
341
342
def getSerializeOpTypeMap():
    """
    Build a TOSA-operator -> Serialization-attribute-type map by name matching.

    Utility for producing the table used in getSerializeOpType(). Attribute names
    from attribute.def are converted from CamelCase to snake_case. Each TOSA
    operator in tosa.xml is then mapped to an attribute whose snake_case name is a
    substring of the operator name, converted back to CamelCase.

    NOTE(review): candidates are tried longest-first, but each later match
    overwrites the earlier one, so the SHORTEST matching attribute name wins. For
    example, if both "transpose_conv" and "conv" match, "conv" is kept. Confirm
    this is intended before relying on the output.
    """
    import re

    camelBoundary = re.compile(r"(?<!^)(?=[A-Z])")
    # Longest names first. sorted() is stable (also with reverse=True), so names of
    # equal length keep the order they have in attribute.def.
    snakeNames = sorted(
        (camelBoundary.sub("_", attName).lower() for attName in getSerializeArgs()),
        key=len,
        reverse=True,
    )
    specXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    tosaOpNames = [
        opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        for opXml in specXml.getElementsByTagName("operator")
    ]
    opTypeMap = {}
    for tosaOpName in tosaOpNames:
        for snakeName in snakeNames:
            if snakeName in tosaOpName:
                # "transpose_conv" -> "Transpose_Conv" -> "TransposeConv"
                opTypeMap[tosaOpName] = snakeName.title().replace("_", "")
    return opTypeMap
367
368
if __name__ == "__main__":
    # All paths below are relative to the scripts/operator_api/ directory.
    environment = Environment(loader=FileSystemLoader("templates/"))
    # tosaXml is deliberately bound at module scope, in case helpers in this file
    # read it as a global.
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    generate(
        environment,
        getTosaDataTypes(tosaXml),
        getOperators(tosaXml),
    )