blob: d9077f06d256a8caff61e19cf7ea05620548ded6 [file] [log] [blame]
"""Generate extended reference model API with eager operator execution entrypoints"""
# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
4import copy
5import os
6import subprocess
Eric Kunze99f8f9f2023-09-07 01:36:07 +00007from pathlib import Path
Grant Watson64285a12022-11-16 15:32:39 +00008from xml.dom import minidom
9
10from jinja2 import Environment
11from jinja2 import FileSystemLoader
12
# Note: main script designed to be run from the scripts/operator_api/ directory
14
Grant Watson64285a12022-11-16 15:32:39 +000015
def getBasePath():
    """
    Returns the directory three levels above this script's resolved location.
    """
    scriptDir = Path(__file__).resolve().parent
    return scriptDir.parent.parent
18
19
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names from tosa.xml.

    The result starts from a fixed set of tensor type names and is extended
    with the "name" attribute of every <type> element found in tosaXml.
    "TABLE_SIZE" is never part of the result.

    tosaXml: parsed XML node exposing getElementsByTagName().
    """
    argTypes = {
        "tensor_t",
        "in_t",
        "out_t",
        "mul_t",
        "weight_t",
        "in_out_t",
        "tensor_list_t",
    }
    argTypes.update(
        argTypeXml.getAttribute("name")
        for argTypeXml in tosaXml.getElementsByTagName("type")
    )
    # discard() instead of remove(): a specification without a TABLE_SIZE
    # type must not abort generation with a KeyError
    argTypes.discard("TABLE_SIZE")
    return argTypes
38
39
def getTosaDataTypes(tosaXml):
    """
    Returns a sorted list of the TOSA data types from tosa.xml.

    Every <typesupport> element is probed for each known argument type name;
    each non-empty attribute value found is reported once, prefixed with
    "tosa_datatype_".
    """
    typeNames = getTosaArgTypes(tosaXml)
    found = {
        f"tosa_datatype_{value}"
        for supportXml in tosaXml.getElementsByTagName("typesupport")
        for value in (supportXml.getAttribute(typeName) for typeName in typeNames)
        if value != ""
    }
    return sorted(found)
53
54
def getSerializeOpType(tosaOpName):
    """
    Returns the Serialization library operator that matches the TOSA operator specified.

    tosaOpName: lower-case TOSA operator name, e.g. "conv2d".
    Returns the matching Serialization library operator name, or the string
    "None" (not the None object) when the operator has no entry in the table.
    """
    # Named to avoid shadowing the builtin map()
    serializeOpTypes = {
        "avg_pool2d": "Pool",
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "fully_connected": "FullyConnected",
        "matmul": "MatMul",
        "max_pool2d": "Pool",
        "transpose_conv2d": "TransposeConv",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    return serializeOpTypes.get(tosaOpName, "None")
87
88
def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
    """
    Returns the attributes required by the Serialization library for the TOSA operator specified.

    Each returned attribute is a copy of the entry in allSerialLibAtts with an
    extra "init" key holding generated C++ code that initializes the attribute.
    If a matching TOSA argument exists, that value is used for initialization,
    otherwise a default value e.g. 0 is used. An empty dict is returned when the
    operator has no Serialization library attributes.
    """
    opType = getSerializeOpType(tosaOpName)
    if opType not in allSerialLibAtts:
        return {}

    # Work on a copy so the caller's attribute table is not modified
    opAtts = copy.deepcopy(allSerialLibAtts[opType])
    argsByName = {arg["name"]: arg for arg in tosaArgs}
    translatedTypes = {"ResizeMode": "tosa_mode"}
    renamesPerOp = {
        "reshape": {"new_shape": "shape"},
        "transpose_conv2d": {"output_shape": "out_shape"},
    }

    # Rename attributes so that they match the names used in tosa.xml
    for defName, specName in renamesPerOp.get(tosaOpName, {}).items():
        for entry in opAtts:
            if entry["name"] == defName:
                entry["name"] = specName

    for att in opAtts:
        attName = att["name"]
        attType = att["dType"]
        isVector = att["SV"] == "V"

        if attType in translatedTypes:
            # Translate TOSA data types to Serialization library data types
            init = f"const {attType} {attName} = translate_client_{translatedTypes[attType]}(client_{attName});"
        elif tosaOpName == "avg_pool2d" and attName == "accum_dtype":
            init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);"
            att["dType"] = "tosa::DType"
        elif attName in argsByName:
            # Attribute has a matching function parameter
            matchingArg = argsByName[attName]
            if not isVector:
                init = ""
            elif matchingArg["type"] == "tosa_tensor_t":
                init = "".join(
                    [
                        f"std::vector<{attType}> {attName};",
                        f"size_t {attName}_size = client_{attName}.size / sizeof({attType});",
                        f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);",
                        f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);",
                    ]
                )
            else:
                shape = matchingArg["shape"]
                if shape == "[]":
                    ctorArgs = f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
                else:
                    ctorArgs = f"(&client_{attName}[0], &client_{attName}{shape});"
                init = f"const std::vector<{attType}> {attName}" + ctorArgs
        elif isVector:
            # No matching function parameter: fall back to defaults
            init = f"std::vector<int32_t> {attName};"
        elif attType == "DType":
            att["dType"] = "tosa::DType"
            init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;"
        else:
            init = f"const {attType} {attName} = 0;"

        att["init"] = init
    return opAtts
Grant Watson64285a12022-11-16 15:32:39 +0000166
167
def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
    """
    Reconcile the TOSA argument list with the Serialization attributes, in place.

    Replace TOSA argument data types with their matching Serialization attribute data types.
    Delete TOSA arguments where the type couldn't be determined (and their "_len" companion).
    Add Serialization attributes that have no matching TOSA argument.
    Both tosaArgs and the entries of serialLibAtts may be modified; nothing is returned.
    """
    knownArgTypes = getTosaArgTypes(tosaXml)
    attsByName = {att["name"]: att for att in serialLibAtts}
    originalNames = [arg["name"] for arg in tosaArgs]

    # Pass 1: retype arguments that have a matching attribute, collect the rest for deletion
    doomedIndices = []
    for arg in tosaArgs:
        if arg["type"] not in knownArgTypes:
            continue
        name = arg["name"]
        if name in attsByName:
            arg["type"] = attsByName[name]["dType"]
            continue
        # Data type can't be determined: drop the argument ...
        doomedIndices.append(originalNames.index(name))
        # ... and its length argument if one exists
        lengthName = f"{arg['name']}_len"
        if lengthName in originalNames:
            doomedIndices.append(originalNames.index(lengthName))

    # Delete from the back so the remaining indices stay valid
    for position in sorted(doomedIndices, reverse=True):
        del tosaArgs[position]

    # Pass 2: add Serialization attributes that have no matching TOSA argument
    remainingNames = [arg["name"] for arg in tosaArgs]
    for att in serialLibAtts:
        name = att["name"]
        dType = att["dType"]
        if name in remainingNames or dType == "tosa::DType":
            continue
        if att["SV"] == "V":
            # For vector data types, insert a matching length argument
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{name}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            att["init"] = f"const std::vector<{dType}> {name}(&client_{name}[0], &client_{name}[0] + client_{name}_len);"
            shape = "[]"
        else:
            att["init"] = ""
            shape = ""
        # New arguments go just before the current last entry
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": name,
                "type": dType,
                "shape": shape,
                "category": "",
            },
        )
227
228
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each entry is a dict with the keys "name", "serializeAttType", "arguments",
    "serialLibAtts", "inputs" and "outputs". Operators named in the ignore list
    below are left out.
    """
    skippedOps = [
        "while_loop",
        "cond_if",
        "const",
        "custom",
        "fft2d",
        "rfft2d",
        "variable",
        "variable_read",
        "variable_write",
    ]
    operatorNodes = tosaXml.getElementsByTagName("operator")
    allSerialLibAtts = getSerialLibAtts()

    collected = []
    for opXml in operatorNodes:
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in skippedOps:
            continue

        operator = {"name": opName, "serializeAttType": getSerializeOpType(opName)}
        tosaArgs = getTosaArgs(opXml)
        serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)

        # Handle "axis" arguments: operators without a Serialization type but
        # with an "axis" argument use the Axis attribute
        hasAxis = any(arg["name"] == "axis" for arg in tosaArgs)
        if operator["serializeAttType"] == "None" and hasAxis:
            operator["serializeAttType"] = "Axis"
            serialLibAtts = [
                {
                    "name": "axis",
                    "dType": "int32_t",
                    "SV": "S",
                    "init": "",
                }
            ]

        updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)
        operator["arguments"] = tosaArgs
        operator["serialLibAtts"] = serialLibAtts

        # Inputs that are passed as Serialization attributes are not listed as inputs
        attNames = [att["name"] for att in serialLibAtts]
        inputNames = []
        outputNames = []
        for arg in tosaArgs:
            if arg["category"] == "input" and arg["name"] not in attNames:
                inputNames.append(arg["name"])
            if arg["category"] == "output":
                outputNames.append(arg["name"])
        operator["inputs"] = inputNames
        operator["outputs"] = outputNames

        collected.append(operator)
    return collected
281
282
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    opXml: the <operator> XML element to inspect.
    tosaXml: optional parsed tosa.xml document used to look up the TOSA
        argument types. When omitted, the document that owns opXml is used, so
        the function no longer depends on a module-level "tosaXml" global that
        only exists when the file is run as a script.

    Returns a list of dicts with the keys "name", "type", "shape" and
    "category". Arrays whose size is not a numeric literal are preceded by an
    extra "<name>_len" argument of type int32_t.
    """
    if tosaXml is None:
        tosaXml = opXml.ownerDocument
    arguments = []
    argsXml = opXml.getElementsByTagName("argument")
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    tensorElemTypeMap = {
        "resize_mode_t": "tosa_mode_t",
        "acc_size_t": "tosa_acc_size_t",
    }
    for xmlArg in argsXml:
        argName = xmlArg.getAttribute("name").lower()
        # A recognised tensor-element-type takes precedence over the "type" attribute
        tensorElemType = xmlArg.getAttribute("tensor-element-type")
        if tensorElemType in tensorElemTypeMap:
            argType = tensorElemTypeMap[tensorElemType]
        else:
            argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # FullyConnected workaround
        if argName in ("weight", "bias") and argCategory == "attribute":
            argCategory = "input"
        # Update argument type: drop a trailing "*" from the type name
        if argType[-1:] == "*":
            argType = argType[:-1]
        if argCategory in ["input", "output"] and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        if argType in tosaTypeMap:
            argType = tosaTypeMap[argType]
        # Add a length argument for arrays with unknown compile-time size
        if argShape != "" and argShape[0] == "[" and not argShape[1:-1].isnumeric():
            argShape = "[]"
            arguments.append(
                {
                    "name": f"{argName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                }
            )
        elif argShape == "" or not argShape[0] == "[":
            # Anything that is not a bracketed array shape is treated as no shape
            argShape = ""
        # Append argument
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
338
339
def clangFormat(filename):
    """
    Runs "clang-format -i" on filename, discarding the tool's standard output.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename], stdout=subprocess.DEVNULL
    )
344
345
def getSerialLibAtts():
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are the arguments required by each Serialization library operator.

    Each argument is a dict with the keys "name", "dType" and "SV", taken from
    the third, first and second comma-separated field of its line respectively.
    """
    marker = "DEF_ATTRIBUTE("
    attrDefPath = (
        getBasePath() / "thirdparty" / "serialization_lib" / "include" / "attribute.def"
    )
    parsedAtts = {}
    with open(attrDefPath) as defFile:
        seenFirstDef = False
        inAtt = False
        opName = ""
        pendingArgs = []
        for rawLine in defFile:
            # Skip everything before the first line that starts a definition
            if not seenFirstDef:
                if not rawLine.startswith(marker):
                    continue
                seenFirstDef = True
            text = rawLine.strip()
            if not inAtt:
                if marker in text:
                    # Operator name sits between the marker and the first comma
                    opName = text[len(marker) : text.find(",")]
                    inAtt = True
                continue
            fields = text.split(",")
            argName = fields[2].strip()
            if ")" in argName:
                # Drop the last character of a name that contains ")"
                argName = argName[:-1]
            pendingArgs.append(
                {
                    "name": argName,
                    "dType": fields[0].strip(),
                    "SV": fields[1].strip(),
                }
            )
            if ")" in text:
                # A ")" anywhere on the line ends the current definition
                parsedAtts[opName] = pendingArgs
                opName = ""
                pendingArgs = []
                inAtt = False
    return parsedAtts
Grant Watson64285a12022-11-16 15:32:39 +0000387
388
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Renders template with the given data types and operators, writes the result
    to outfile as UTF-8, reports the created file and formats it with clangFormat().

    The environment argument is accepted but not used here.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as outputFile:
        outputFile.write(rendered)
        print(f"Created {outfile}")

    # Format only after the file has been closed
    clangFormat(outfile)
396
397
def generate(environment, dataTypes, operators, base_path):
    """
    Renders the operator API header and source files below base_path.
    """
    # (template name, output path relative to base_path)
    targets = (
        # Generate include/operators.h
        ("operators_h.j2", "reference_model/include/operators.h"),
        # Generate src/operators.cc
        ("operators_cc.j2", "reference_model/src/operators.cc"),
    )
    for templateName, relativePath in targets:
        template = environment.get_template(templateName)
        outfile = base_path / relativePath
        renderTemplate(environment, dataTypes, operators, template, outfile)
408
409
if __name__ == "__main__":
    # Paths are derived from this file's location, not the current working directory
    base_path = getBasePath()
    templateDir = Path(__file__).resolve().parent / "templates"
    environment = Environment(loader=FileSystemLoader(templateDir))

    # Keep the name "tosaXml": it is a module-level global here and may be
    # referenced by name elsewhere in the file
    specPath = base_path / "thirdparty/specification/tosa.xml"
    tosaXml = minidom.parse(str(specPath))

    dataTypes = getTosaDataTypes(tosaXml)
    operators = getOperators(tosaXml)
    generate(environment, dataTypes, operators, base_path)