blob: c5c762d2f52f2d369ff6cd240132343b099ace92 [file] [log] [blame]
Grant Watson64285a12022-11-16 15:32:39 +00001"""Generate extended reference model API with eager operator execution entrypoints"""
Grant Watsone70d9312023-08-28 16:34:28 +01002# Copyright (c) 2021-2023, ARM Limited.
Grant Watson64285a12022-11-16 15:32:39 +00003# SPDX-License-Identifier: Apache-2.0
4import copy
5import os
6import subprocess
7from xml.dom import minidom
8
9from jinja2 import Environment
10from jinja2 import FileSystemLoader
11
James Wardd34b3fc2023-01-18 14:51:25 +000012# Note: main script designed to be run from the scripts/operator_api/ directory
13
Grant Watson64285a12022-11-16 15:32:39 +000014
def getTosaArgTypes(tosaXml):
    """
    Return the set of TOSA argument type names.

    The set holds a fixed group of generic type names ("tensor_t", "in_t", ...)
    plus the ``name`` attribute of every ``<type>`` element in tosa.xml, with
    "TABLE_SIZE" excluded.
    """
    argTypes = {
        "tensor_t",
        "in_t",
        "out_t",
        "mul_t",
        "weight_t",
        "in_out_t",
        "tensor_list_t",
    }
    argTypes.update(
        typeXml.getAttribute("name")
        for typeXml in tosaXml.getElementsByTagName("type")
    )
    # discard() rather than remove(): a tosa.xml without a TABLE_SIZE type must
    # not abort generation with a KeyError.
    argTypes.discard("TABLE_SIZE")
    return argTypes
33
34
def getTosaDataTypes(tosaXml):
    """
    Return the sorted list of TOSA data type identifiers used in tosa.xml.

    For every ``<typesupport>`` element, each non-empty value of an argument-type
    attribute (as listed by getTosaArgTypes) is collected once and prefixed with
    "tosa_datatype_".
    """
    argTypes = getTosaArgTypes(tosaXml)
    found = {
        f"tosa_datatype_{value}"
        for support in tosaXml.getElementsByTagName("typesupport")
        for value in map(support.getAttribute, argTypes)
        if value != ""
    }
    return sorted(found)
48
49
# TOSA operator name -> Serialization library attribute type.
# getSerializeOpTypeMap() is a helper for regenerating a table like this one.
_SERIALIZE_OP_TYPES = {
    "avg_pool2d": "Pool",
    "conv2d": "Conv",
    "conv3d": "Conv",
    "depthwise_conv2d": "Conv",
    "fully_connected": "FullyConnected",
    "matmul": "MatMul",
    "max_pool2d": "Pool",
    "transpose_conv2d": "Conv",
    "clamp": "Clamp",
    "arithmetic_right_shift": "ArithmeticRightShift",
    "mul": "Mul",
    "table": "Table",
    "negate": "Negate",
    "pad": "Pad",
    "reshape": "Reshape",
    "slice": "Slice",
    "tile": "Tile",
    "transpose": "Transpose",
    "resize": "Resize",
    "rescale": "Rescale",
    "cond_if": "CondIf",
    "while_loop": "WhileLoop",
}


def getSerializeOpType(tosaOpName):
    """
    Return the Serialization library attribute type that matches the TOSA operator specified.

    Operators without an entry map to the string "None" (not the None object);
    callers compare against that string.
    """
    return _SERIALIZE_OP_TYPES.get(tosaOpName, "None")
82
83
def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
    """
    Return the Serialization library attributes for the TOSA operator specified.

    The result is a deep copy of the operator's entry in ``allSerialLibAtts``, in
    which every attribute dict gains an "init" key holding the C++ code that
    declares and initializes it. If a TOSA argument with the same name exists,
    the attribute is initialized from the matching ``client_<name>`` function
    parameter; otherwise a default value (e.g. 0) is used. Scalar attributes with
    a matching argument get an empty "init". An empty dict is returned when the
    operator has no Serialization library attributes.
    """
    opType = getSerializeOpType(tosaOpName)
    if opType not in allSerialLibAtts.keys():
        return {}

    # Work on a copy: "init" is added (and "dType" may be rewritten) per operator.
    opAtts = copy.deepcopy(allSerialLibAtts[opType])
    argsByName = {arg["name"]: arg for arg in tosaArgs}
    # Serialization library types that are translated from a client-facing type
    clientTypeNames = {"ResizeMode": "tosa_mode"}

    for att in opAtts:
        name = att["name"]
        dType = att["dType"]
        hasClientArg = name in argsByName

        if dType in clientTypeNames.keys():
            init = (
                f"const {dType} {name} = "
                f"translate_client_{clientTypeNames[dType]}(client_{name});"
            )
        elif hasClientArg and att["SV"] != "V":
            # Scalar with a matching function parameter: nothing to declare
            init = ""
        elif hasClientArg and argsByName[name]["type"] == "tosa_tensor_t":
            # Copy the raw client tensor buffer into a typed std::vector
            init = (
                f"std::vector<{dType}> {name};"
                f"size_t {name}_size = client_{name}.size / sizeof({dType});"
                f"{dType}* {name}_data = reinterpret_cast<{dType}*>(client_{name}.data);"
                f"{name}.assign({name}_data, {name}_data + {name}_size);"
            )
        elif hasClientArg:
            # Client array: an unsized ("[]") array ends client_<name>_len elements
            # in, otherwise the end is given by the argument's fixed shape.
            shape = argsByName[name]["shape"]
            if shape == "[]":
                end = f"&client_{name}[0] + client_{name}_len"
            else:
                end = f"&client_{name}{shape}"
            init = f"const std::vector<{dType}> {name}(&client_{name}[0], {end});"
        elif att["SV"] == "V":
            # No matching function parameter: declare an empty vector
            init = f"std::vector<int32_t> {name};"
        elif dType == "DType":
            att["dType"] = "tosa::DType"
            init = f"const tosa::DType {name} = tosa::DType::DType_FP32;"
        else:
            init = f"const {dType} {name} = 0;"
        att["init"] = init
    return opAtts
Grant Watson64285a12022-11-16 15:32:39 +0000148
149
def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
    """
    Reconcile TOSA arguments with Serialization library attributes, in place.

    1. For each TOSA argument whose type is a generic TOSA argument type (see
       getTosaArgTypes), use the "dType" of the Serialization attribute of the
       same name.
    2. Delete such arguments when no attribute matches (their type can't be
       determined), together with their "<name>_len" companion, if present.
    3. Insert every Serialization attribute that has no matching TOSA argument
       (except "tosa::DType" attributes) just before the last argument. Vector
       attributes are preceded by an int32_t "<name>_len" argument and get a
       std::vector "init" statement; scalar ones get an empty "init".

    Args:
        tosaArgs: list of argument dicts ("name", "type", "shape", "category");
            modified in place.
        serialLibAtts: Serialization attribute dicts; their "init" entries may be
            overwritten.
        tosaXml: parsed tosa.xml document.
    """
    tosaArgTypes = getTosaArgTypes(tosaXml)
    serAttsDict = {att["name"]: att for att in serialLibAtts}
    tosaArgsNames = [arg["name"] for arg in tosaArgs]

    # Collected in a set so an index flagged twice is deleted only once; a
    # repeated `del` at the same index would remove an unrelated argument.
    delTosaArgs = set()
    for tosaArg in tosaArgs:
        if tosaArg["type"] not in tosaArgTypes:
            continue
        if tosaArg["name"] in serAttsDict:
            tosaArg["type"] = serAttsDict[tosaArg["name"]]["dType"]
        else:
            # The argument's data type can't be determined: drop it and its length
            delTosaArgs.add(tosaArgsNames.index(tosaArg["name"]))
            lenArgName = f"{tosaArg['name']}_len"
            if lenArgName in tosaArgsNames:
                delTosaArgs.add(tosaArgsNames.index(lenArgName))
    # Delete from the back so earlier indices stay valid
    for index in sorted(delTosaArgs, reverse=True):
        del tosaArgs[index]

    # Add Serialization attributes that have no matching TOSA argument
    tosaArgNames = {arg["name"] for arg in tosaArgs}
    for serAtt in serialLibAtts:
        attName = serAtt["name"]
        attType = serAtt["dType"]
        if attName in tosaArgNames or attType == "tosa::DType":
            continue
        if serAtt["SV"] == "V":
            # Vector attributes need a matching length argument
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{attName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            serAtt["init"] = (
                f"const std::vector<{attType}> {attName}"
                f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
            )
            shape = "[]"
        else:
            serAtt["init"] = ""
            shape = ""
        # Insert the new argument just before the last argument
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": attName,
                "type": attType,
                "shape": shape,
                "category": "",
            },
        )
209
210
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each operator is a dict with the keys:
      name             -- lower-case operator name
      serializeAttType -- Serialization library attribute type; "Axis" for an
                          otherwise unmapped operator that has an "axis"
                          argument, the string "None" when there is none
      arguments        -- function arguments, reconciled by updateTosaArgs()
      serialLibAtts    -- Serialization library attributes with their "init" code
      inputs           -- names of "input" arguments that are not attributes
      outputs          -- names of "output" arguments
    Operators listed in ``skippedOps`` are not returned.
    """
    skippedOps = {
        "while_loop",
        "cond_if",
        "const",
        "custom",
        "fft2d",
        "rfft2d",
        "variable",
        "variable_read",
        "variable_write",
    }
    opsXml = tosaXml.getElementsByTagName("operator")
    allSerialLibAtts = getSerialLibAtts()
    operators = []
    for opXml in opsXml:
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in skippedOps:
            continue

        attType = getSerializeOpType(opName)
        tosaArgs = getTosaArgs(opXml)
        serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)
        # Unmapped operators that take an "axis" argument use the Axis attribute
        if attType == "None" and any(arg["name"] == "axis" for arg in tosaArgs):
            attType = "Axis"
            serialLibAtts = [
                {"name": "axis", "dType": "int32_t", "SV": "S", "init": ""}
            ]
        updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)

        attNames = {att["name"] for att in serialLibAtts}
        operators.append(
            {
                "name": opName,
                "serializeAttType": attType,
                "arguments": tosaArgs,
                "serialLibAtts": serialLibAtts,
                "inputs": [
                    arg["name"]
                    for arg in tosaArgs
                    if arg["category"] == "input" and arg["name"] not in attNames
                ],
                "outputs": [
                    arg["name"] for arg in tosaArgs if arg["category"] == "output"
                ],
            }
        )
    return operators
263
264
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    Args:
        opXml: the ``<operator>`` element from tosa.xml.
        tosaXml: the parsed tosa.xml document. Defaults to ``opXml.ownerDocument``,
            so the function no longer depends on a module-level ``tosaXml`` global,
            which only exists when this file is run as a script.

    Returns:
        A list of dicts with "name", "type", "shape" and "category" keys. Tensor
        inputs/outputs get type "tosa_tensor_t" and no shape. An array argument
        whose size is not a literal number gets shape "[]" and is preceded by an
        int32_t "<name>_len" argument.
    """
    if tosaXml is None:
        tosaXml = opXml.ownerDocument
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    arguments = []
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t":
            argType = "tosa_mode_t"
        else:
            argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # FullyConnected workaround: weight/bias attributes are treated as inputs
        if argName in ("weight", "bias") and argCategory == "attribute":
            argCategory = "input"
        # Update argument type: drop a trailing pointer marker
        if argType.endswith("*"):
            argType = argType[:-1]
        if argCategory in ("input", "output") and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        argType = tosaTypeMap.get(argType, argType)
        if argShape.startswith("[") and not argShape[1:-1].isnumeric():
            # Add a length argument for arrays with unknown compile-time size
            argShape = "[]"
            arguments.append(
                {
                    "name": f"{argName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                }
            )
        elif not argShape.startswith("["):
            argShape = ""
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
315
316
def clangFormat(filename):
    """
    Format ``filename`` in place with clang-format, discarding its standard output.

    Raises subprocess.CalledProcessError if clang-format exits with a non-zero status.
    """
    subprocess.run(
        ["clang-format", "-i", filename],
        stdout=subprocess.DEVNULL,
        check=True,
    )
321
322
def getSerialLibAtts(
    attributeDefPath="../../thirdparty/serialization_lib/include/attribute.def",
):
    """
    Parse the Serialization library attribute.def file.

    Returns a dictionary whose keys are the Serialization library attribute names
    given to DEF_ATTRIBUTE (e.g. "Pool"). Each value is the list of arguments, in
    file order, as dicts with "name", "dType" and "SV" keys ("S" scalar / "V" vector).

    Args:
        attributeDefPath: path to attribute.def. The default is relative to the
            scripts/operator_api/ directory this script is designed to run from.
    """
    prefix = "DEF_ATTRIBUTE("
    serialLibAtts = {}
    with open(attributeDefPath) as file:
        preamble = True
        inAtt = False
        opName = ""
        args = []
        for line in file:
            # Skip everything before the first line that starts (in column 0) with
            # DEF_ATTRIBUTE(; the check is on the unstripped line, so indented
            # mentions before it are skipped too.
            if preamble and not line.startswith(prefix):
                continue
            preamble = False
            line = line.strip()
            if not line:
                # Blank lines must not be parsed as argument lines
                continue
            if not inAtt and prefix in line:
                opName = line[len(prefix) : line.find(",")]
                args = []
                inAtt = True
                # A definition with no arguments closes on its own header line
                if ")" in line:
                    serialLibAtts[opName] = args
                    inAtt = False
            elif inAtt:
                # Argument line: "<dType>, <S|V>, <name>," with ")" after the last one
                vals = line.split(",")
                args.append(
                    {
                        "name": vals[2].strip().rstrip(")").strip(),
                        "dType": vals[0].strip(),
                        "SV": vals[1].strip(),
                    }
                )
                if ")" in line:
                    serialLibAtts[opName] = args
                    inAtt = False
    return serialLibAtts
Grant Watson64285a12022-11-16 15:32:39 +0000360
361
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render ``template``, write the result to ``outfile`` and clang-format it.

    Args:
        environment: unused; kept so existing callers keep working.
        dataTypes: passed to the template as ``dataTypes``.
        operators: passed to the template as ``operators``.
        template: a Jinja2 template (anything with ``render(**context)``).
        outfile: path of the file to (over)write, encoded as UTF-8.
    """
    content = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as output:
        output.write(content)
    print(f"Created {outfile}")

    # clang-format rewrites the file in place, so run it only after the `with`
    # block has closed (and flushed) our handle; otherwise a late flush could
    # overwrite the formatted result.
    clangFormat(outfile)
369
370
def generate(environment, dataTypes, operators):
    """
    Render the operator API header and source from their Jinja2 templates.

    Writes reference_model/include/operators.h (from operators_h.j2) and then
    reference_model/src/operators.cc (from operators_cc.j2), using paths relative
    to the directory two levels above the current one.
    """
    targets = (
        ("operators_h.j2", "include", "operators.h"),
        ("operators_cc.j2", "src", "operators.cc"),
    )
    for templateName, subdir, filename in targets:
        template = environment.get_template(templateName)
        outfile = os.path.join("..", "..", "reference_model", subdir, filename)
        renderTemplate(environment, dataTypes, operators, template, outfile)
381
382
def getSerializeOpTypeMap(allSerialLibAtts=None, tosaXml=None):
    """
    Utility function for generating the map used in getSerializeOpType().

    Each lower-case TOSA operator name is matched against the snake_case form of
    every Serialization library attribute name. The longest attribute name that
    is a substring of the operator name wins; operators matching none are left out.

    Args:
        allSerialLibAtts: mapping as returned by getSerialLibAtts(); read from
            attribute.def when omitted.
        tosaXml: parsed tosa.xml document; parsed from the default location
            (relative to scripts/operator_api/) when omitted.

    Returns:
        dict of TOSA operator name -> Serialization library attribute name
        (e.g. "transpose_conv2d" -> "TransposeConv").
    """
    import re

    if allSerialLibAtts is None:
        allSerialLibAtts = getSerialLibAtts()
    if tosaXml is None:
        tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")

    # CamelCase -> snake_case, longest first so that the most specific attribute
    # (e.g. "transpose_conv") is tried before a shorter one it contains ("conv").
    serAtts = sorted(
        (re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() for name in allSerialLibAtts),
        key=len,
        reverse=True,
    )
    opTypeMap = {}
    for opXml in tosaXml.getElementsByTagName("operator"):
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        # Keep the first (longest) match; the original loop let later, shorter
        # matches overwrite it, which defeated the length ordering.
        match = next((serAtt for serAtt in serAtts if serAtt in opName), None)
        if match is not None:
            opTypeMap[opName] = "".join(part.title() for part in match.split("_"))
    return opTypeMap
407
408
if __name__ == "__main__":
    # All paths are relative to scripts/operator_api/, the directory this script
    # is designed to be run from. tosaXml is bound at module level on purpose, in
    # case helpers reference it as a global.
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    templateEnv = Environment(loader=FileSystemLoader("templates/"))
    generate(templateEnv, getTosaDataTypes(tosaXml), getOperators(tosaXml))