blob: 7f105680718ccef94e6c67d9ff5065f2a8543bd7 [file] [log] [blame]
Grant Watson64285a12022-11-16 15:32:39 +00001"""Generate extended reference model API with eager operator execution entrypoints"""
Grant Watsone70d9312023-08-28 16:34:28 +01002# Copyright (c) 2021-2023, ARM Limited.
Grant Watson64285a12022-11-16 15:32:39 +00003# SPDX-License-Identifier: Apache-2.0
4import copy
5import os
6import subprocess
Eric Kunze99f8f9f2023-09-07 01:36:07 +00007from pathlib import Path
Grant Watson64285a12022-11-16 15:32:39 +00008from xml.dom import minidom
9
10from jinja2 import Environment
11from jinja2 import FileSystemLoader
12
James Wardd34b3fc2023-01-18 14:51:25 +000013# Note: main script designed to be run from the scripts/operator_api/ directory
14
Grant Watson64285a12022-11-16 15:32:39 +000015
def getBasePath():
    """
    Returns the grandparent of the directory that contains this script,
    computed from the fully resolved location of the script file.
    """
    scriptPath = Path(__file__).resolve()
    scriptDir = scriptPath.parent
    return scriptDir.parent.parent
18
19
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names.

    The result starts from a fixed group of type names and is extended with the
    "name" attribute of every <type> element found in tosaXml. The name
    "TABLE_SIZE" is always excluded from the result.
    """
    argTypes = {
        "tensor_t",
        "in_t",
        "out_t",
        "mul_t",
        "weight_t",
        "in_out_t",
        "tensor_list_t",
    }
    for argTypeXml in tosaXml.getElementsByTagName("type"):
        argTypes.add(argTypeXml.getAttribute("name"))
    # discard() instead of remove(): a document that declares no "TABLE_SIZE"
    # type must not abort the generator with a KeyError.
    argTypes.discard("TABLE_SIZE")
    return argTypes
38
39
def getTosaDataTypes(tosaXml):
    """
    Returns a sorted list of the TOSA data types from tosa.xml.

    For every <typesupport> element, each known argument type name is read as an
    attribute; every non-empty value is reported once, prefixed with
    "tosa_datatype_".
    """
    knownArgTypes = getTosaArgTypes(tosaXml)
    found = {
        f"tosa_datatype_{value}"
        for supportNode in tosaXml.getElementsByTagName("typesupport")
        for value in (supportNode.getAttribute(name) for name in knownArgTypes)
        if value != ""
    }
    return sorted(found)
53
54
def getSerializeOpType(tosaOpName):
    """
    Looks up the Serialization library operator that matches the given TOSA operator.
    Returns the string "None" when the operator has no entry in the table.
    """
    opTypeByTosaName = {
        "avg_pool2d": "Pool",
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "fully_connected": "FullyConnected",
        "fft2d": "FFT",
        "rfft2d": "RFFT",
        "matmul": "MatMul",
        "max_pool2d": "Pool",
        "transpose_conv2d": "TransposeConv",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    return opTypeByTosaName.get(tosaOpName, "None")
89
90
def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
    """
    Returns the attributes required by the Serialization library for the TOSA operator specified.

    Works on a deep copy of the operator's entry in allSerialLibAtts, so the caller's
    data is left untouched. Each returned attribute gains an "init" entry holding the
    generated code that initializes it: from the TOSA argument of the same name when
    one exists, otherwise from a default value e.g. 0.
    Returns an empty dict when the operator's Serialization library type is not a key
    of allSerialLibAtts.
    """
    serLibOpType = getSerializeOpType(tosaOpName)
    if serLibOpType not in allSerialLibAtts:
        return {}

    serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
    argsByName = {arg["name"]: arg for arg in tosaArgs}
    serTosaTypeMap = {"ResizeMode": "tosa_mode"}
    serAttsToFix = {
        "reshape": {"new_shape": "shape"},
        "transpose_conv2d": {"output_shape": "out_shape"},
    }

    # Fix attributes names to match with tosa.xml
    for attDefName, tosaSpecName in serAttsToFix.get(tosaOpName, {}).items():
        for opAtts in serLibOpAtts:
            if opAtts["name"] == attDefName:
                opAtts["name"] = tosaSpecName

    for att in serLibOpAtts:
        attName = att["name"]
        attType = att["dType"]
        if attType in serTosaTypeMap:
            # Translate TOSA data types to Serialization library data types for initialization
            init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});"
        elif tosaOpName == "avg_pool2d" and attName == "accum_dtype":
            init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);"
            att["dType"] = "tosa::DType"
        elif attName in argsByName:
            # Initialize Serialization library attributes to their matching function parameter
            matchingArg = argsByName[attName]
            if att["SV"] != "V":
                init = ""
            elif matchingArg["type"] == "tosa_tensor_t":
                init = "".join(
                    (
                        f"std::vector<{attType}> {attName};",
                        f"size_t {attName}_size = client_{attName}.size / sizeof({attType});",
                        f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);",
                        f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);",
                    )
                )
            else:
                shape = matchingArg["shape"]
                if shape == "[]":
                    rangeArgs = f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
                else:
                    rangeArgs = f"(&client_{attName}[0], &client_{attName}{shape});"
                init = f"const std::vector<{attType}> {attName}" + rangeArgs
        elif att["SV"] == "V":
            # Initialize Serialization library attributes with no matching function parameter
            init = f"std::vector<int32_t> {attName};"
        elif att["dType"] == "DType":
            att["dType"] = "tosa::DType"
            init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;"
        else:
            init = f"const {attType} {attName} = 0;"
        att["init"] = init
    return serLibOpAtts
Grant Watson64285a12022-11-16 15:32:39 +0000168
169
def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
    """
    Reconciles the TOSA argument list with the Serialization attributes, in place.

    - A TOSA argument whose type is a TOSA argument type takes the data type of the
      Serialization attribute with the same name.
    - If no such attribute exists, the argument is deleted, together with its
      "<name>_len" companion when there is one.
    - Every Serialization attribute without a matching TOSA argument (and whose data
      type is not "tosa::DType") is inserted just before the last argument; vector
      attributes also get a "<name>_len" argument and an "init" string.
    """
    tosaArgTypes = getTosaArgTypes(tosaXml)
    attsByName = {att["name"]: att for att in serialLibAtts}
    originalNames = [arg["name"] for arg in tosaArgs]
    staleIndices = []
    # Replace TOSA argument data types with their matching Serialization attribute data types.
    for tosaArg in tosaArgs:
        if tosaArg["type"] not in tosaArgTypes:
            continue
        if tosaArg["name"] in attsByName:
            tosaArg["type"] = attsByName[tosaArg["name"]]["dType"]
            continue
        # No attribute to take the data type from: schedule the argument for deletion
        staleIndices.append(originalNames.index(tosaArg["name"]))
        # Schedule the corresponding length argument too, if one exists
        lenArgName = f"{tosaArg['name']}_len"
        if lenArgName in originalNames:
            staleIndices.append(originalNames.index(lenArgName))
    # Delete from the highest index down so the remaining indices stay valid
    staleIndices.sort(reverse=True)
    for index in staleIndices:
        del tosaArgs[index]

    # Add Serialization attributes that have no matching TOSA argument
    remainingNames = [arg["name"] for arg in tosaArgs]
    for serAtt in serialLibAtts:
        attName = serAtt["name"]
        attType = serAtt["dType"]
        if attName in remainingNames or attType == "tosa::DType":
            continue
        serAttName = serAtt["name"]
        if serAtt["SV"] == "V":
            # For vector data types, insert a matching length argument
            lengthArg = {
                "name": f"{serAttName}_len",
                "type": "int32_t",
                "shape": "",
                "category": "",
            }
            tosaArgs.insert(len(tosaArgs) - 1, lengthArg)
            init = f"const std::vector<{attType}> {attName}(&client_{serAttName}[0], &client_{serAttName}[0] + client_{serAttName}_len);"
            shape = "[]"
        else:
            init = ""
            shape = ""
        serAtt["init"] = init
        # Insert new argument
        newArg = {
            "name": serAttName,
            "type": serAtt["dType"],
            "shape": shape,
            "category": "",
        }
        tosaArgs.insert(len(tosaArgs) - 1, newArg)
229
230
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each entry is a dict with the operator's lower-cased name, its Serialization
    attribute type, its arguments, its Serialization library attributes, and the
    names of its input and output arguments. Operators named in the skip list are
    left out.
    """
    skippedOps = (
        "while_loop",
        "cond_if",
        "const",
        "custom",
        "variable",
        "variable_read",
        "variable_write",
    )
    opNodes = tosaXml.getElementsByTagName("operator")
    allSerialLibAtts = getSerialLibAtts()
    operators = []
    for opXml in opNodes:
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in skippedOps:
            continue
        serializeAttType = getSerializeOpType(opName)
        tosaArgs = getTosaArgs(opXml)
        serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)
        # Handle "axis" arguments
        hasAxis = any(arg["name"] == "axis" for arg in tosaArgs)
        if serializeAttType == "None" and hasAxis:
            serializeAttType = "Axis"
            serialLibAtts = [
                {
                    "name": "axis",
                    "dType": "int32_t",
                    "SV": "S",
                    "init": "",
                }
            ]
        updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)
        serializationAttNames = [att["name"] for att in serialLibAtts]
        inputNames = [
            arg["name"]
            for arg in tosaArgs
            if arg["category"] == "input"
            and arg["name"] not in serializationAttNames
        ]
        outputNames = [arg["name"] for arg in tosaArgs if arg["category"] == "output"]
        operators.append(
            {
                "name": opName,
                "serializeAttType": serializeAttType,
                "arguments": tosaArgs,
                "serialLibAtts": serialLibAtts,
                "inputs": inputNames,
                "outputs": outputNames,
            }
        )
    return operators
281
282
def getTosaArgs(opXml):
    """
    Return the arguments required for the TOSA operator specified.

    opXml is the <operator> DOM element. The result is a list of dicts with the keys
    "name", "type", "shape" and "category", one per <argument> element, in document
    order. An array argument whose size is not a numeric literal is preceded by an
    extra int32_t "<name>_len" argument and its shape is normalized to "[]".
    """
    # Look the specification document up from the element itself. Reading a
    # module-level `tosaXml` here only works when this file runs as a script and
    # raises NameError when the function is called from an importing module.
    specDoc = opXml.ownerDocument
    if specDoc is None:
        # A Document node has no owner document; use the node itself.
        specDoc = opXml
    tosaTensorTypes = getTosaArgTypes(specDoc)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    tensorElemTypeMap = {
        "resize_mode_t": "tosa_mode_t",
        "acc_size_t": "tosa_acc_size_t",
    }
    arguments = []
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        tensorElemType = xmlArg.getAttribute("tensor-element-type")
        if tensorElemType in tensorElemTypeMap:
            argType = tensorElemTypeMap[tensorElemType]
        else:
            argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # FullyConnected workaround
        if argName in ("weight", "bias") and argCategory == "attribute":
            argCategory = "input"
        # Update argument type: drop a trailing pointer marker
        if argType.endswith("*"):
            argType = argType[:-1]
        if argCategory in ("input", "output") and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        if argType in tosaTypeMap:
            argType = tosaTypeMap[argType]
        if argShape.startswith("[") and not argShape[1:-1].isnumeric():
            # Add a length argument for arrays with unknown compile-time size
            argShape = "[]"
            arguments.append(
                {
                    "name": f"{argName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                }
            )
        elif not argShape.startswith("["):
            # Anything that is not a bracketed array shape is treated as no shape
            argShape = ""
        # Append argument
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
337 return arguments
338
339
def clangFormat(filename):
    """
    Runs `clang-format -i` on the given file, discarding the tool's standard output.
    A non-zero exit status makes subprocess.check_call raise CalledProcessError.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename],
        stdout=subprocess.DEVNULL,
    )
344
345
def getSerialLibAtts():
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are lists of the arguments required by each operator; every argument is a dict
    with the keys "name", "dType" and "SV".
    """
    marker = "DEF_ATTRIBUTE("
    attrDefPath = (
        getBasePath() / "thirdparty" / "serialization_lib" / "include" / "attribute.def"
    )
    serialLibAtts = {}
    with open(attrDefPath) as defFile:
        preamble = True
        inAtt = False
        opName = ""
        args = []
        for rawLine in defFile:
            if preamble:
                # Skip everything before the first line that starts a definition
                if not rawLine.startswith(marker):
                    continue
                preamble = False
            line = rawLine.strip()
            if not inAtt:
                if marker in line:
                    # Operator name sits between the marker and the first comma
                    opName = line[len(marker) : line.find(",")]
                    inAtt = True
                continue
            # Inside a definition: each line reads "<dType>, <SV>, <name>"
            vals = line.split(",")
            argName = vals[2].strip()
            if ")" in argName:
                argName = argName[:-1]
            args.append(
                {
                    "name": argName,
                    "dType": vals[0].strip(),
                    "SV": vals[1].strip(),
                }
            )
            if ")" in line:
                # Closing parenthesis ends the definition; store it and reset state
                serialLibAtts[opName] = args
                opName = ""
                args = []
                inAtt = False
    return serialLibAtts
Grant Watson64285a12022-11-16 15:32:39 +0000387
388
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Renders the template with the given data types and operators, writes the result
    to outfile as UTF-8, reports the created file and then runs clangFormat on it.
    The environment argument is accepted but not used here.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as output:
        output.write(rendered)
        print(f"Created {outfile}")

    # Format only after the file has been closed
    clangFormat(outfile)
396
397
def generate(environment, dataTypes, operators, base_path):
    """
    Renders every generated source file below base_path, one template per file,
    in the order listed.
    """
    targets = (
        # Generate include/operators.h
        ("operators_h.j2", "reference_model/include/operators.h"),
        # Generate src/operators.cc
        ("operators_cc.j2", "reference_model/src/operators.cc"),
    )
    for templateName, relativePath in targets:
        template = environment.get_template(templateName)
        outfile = base_path / relativePath
        renderTemplate(environment, dataTypes, operators, template, outfile)
408
409
if __name__ == "__main__":
    base_path = getBasePath()
    # Templates live in the "templates" directory next to this script
    templateDir = Path(__file__).resolve().parent / "templates"
    environment = Environment(loader=FileSystemLoader(templateDir))
    specPath = base_path / "thirdparty/specification/tosa.xml"
    # Keep the parsed document bound to the module-level name `tosaXml`:
    # functions in this file may look that name up as a global.
    tosaXml = minidom.parse(str(specPath))
    dataTypes = getTosaDataTypes(tosaXml)
    operators = getOperators(tosaXml)
    generate(environment, dataTypes, operators, base_path)
Eric Kunze99f8f9f2023-09-07 01:36:07 +0000418 generate(environment, dataTypes, operators, base_path)