blob: 99639f4e798e6cd5e47da2f146272247d10db510 [file] [log] [blame]
Grant Watson64285a12022-11-16 15:32:39 +00001"""Generate extended reference model API with eager operator execution entrypoints"""
Grant Watsone70d9312023-08-28 16:34:28 +01002# Copyright (c) 2021-2023, ARM Limited.
Grant Watson64285a12022-11-16 15:32:39 +00003# SPDX-License-Identifier: Apache-2.0
4import copy
5import os
6import subprocess
Eric Kunze99f8f9f2023-09-07 01:36:07 +00007from pathlib import Path
Grant Watson64285a12022-11-16 15:32:39 +00008from xml.dom import minidom
9
10from jinja2 import Environment
11from jinja2 import FileSystemLoader
12
# Note: all input and output paths are resolved relative to this file (see getBasePath), so the script can be run from any working directory.
14
Grant Watson64285a12022-11-16 15:32:39 +000015
def getBasePath():
    """Return the repository root: the directory three levels above this script."""
    scriptPath = Path(__file__).resolve()
    return scriptPath.parent.parent.parent
18
19
def getTosaArgTypes(tosaXml):
    """
    Return the set of TOSA argument type names.

    The result holds a fixed group of generic placeholder types (tensor_t,
    in_t, out_t, mul_t, weight_t, in_out_t, tensor_list_t) plus the ``name``
    attribute of every ``<type>`` element in tosa.xml. ``TABLE_SIZE`` is
    excluded so it is never treated as an argument type.

    :param tosaXml: parsed tosa.xml DOM document.
    :returns: a ``set`` of type-name strings.
    """
    argTypes = {
        "tensor_t",
        "in_t",
        "out_t",
        "mul_t",
        "weight_t",
        "in_out_t",
        "tensor_list_t",
    }
    argTypes.update(
        argTypeXml.getAttribute("name")
        for argTypeXml in tosaXml.getElementsByTagName("type")
    )
    # discard() rather than remove(): a tosa.xml that declares no TABLE_SIZE
    # type must not abort generation with a KeyError.
    argTypes.discard("TABLE_SIZE")
    return argTypes
38
39
def getTosaDataTypes(tosaXml):
    """
    Return the sorted, de-duplicated TOSA data type identifiers in tosa.xml.

    For every ``<typesupport>`` element, each attribute named after a TOSA
    argument type (as returned by getTosaArgTypes) that has a non-empty value
    contributes the identifier ``tosa_datatype_<value>``.
    """
    argTypes = getTosaArgTypes(tosaXml)
    found = set()
    for support in tosaXml.getElementsByTagName("typesupport"):
        values = (support.getAttribute(typeName) for typeName in argTypes)
        found.update(f"tosa_datatype_{value}" for value in values if value)
    return sorted(found)
53
54
def getSerializeOpType(tosaOpName):
    """
    Return the Serialization library attribute type for a TOSA operator.

    :param tosaOpName: lower-case TOSA operator name, e.g. ``"conv2d"``.
    :returns: the Serialization library attribute type name (e.g. ``"Conv"``),
        or the string ``"None"`` when the operator has no entry.
    """
    # Named opTypeMap (not `map`) so the builtin is not shadowed.
    opTypeMap = {
        "avg_pool2d": "Pool",
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "fully_connected": "FullyConnected",
        "matmul": "MatMul",
        "max_pool2d": "Pool",
        "transpose_conv2d": "Conv",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    return opTypeMap.get(tosaOpName, "None")
87
88
def _vectorAttInit(attName, attType, tosaArg):
    """Build the C++ initializer for a vector attribute that has a matching TOSA argument."""
    if tosaArg["type"] == "tosa_tensor_t":
        # Copy the elements out of the client tensor's raw data buffer.
        return (
            f"std::vector<{attType}> {attName};"
            f"size_t {attName}_size = client_{attName}.size / sizeof({attType});"
            f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);"
            f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);"
        )
    init = f"const std::vector<{attType}> {attName}"
    shape = tosaArg["shape"]
    if shape == "[]":
        # Size unknown at compile time: bounded by the client_<name>_len parameter.
        return init + f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
    # Otherwise the shape text (presumably a fixed size such as "[4]") forms the end pointer.
    return init + f"(&client_{attName}[0], &client_{attName}{shape});"


def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
    """
    Return the Serialization library attributes required by a TOSA operator.

    Each returned attribute dict is a deep copy of the matching entry in
    ``allSerialLibAtts``, extended with an ``"init"`` key holding generated
    C++ code that initializes the attribute. If a TOSA argument with the same
    name exists, its ``client_<name>`` function parameter supplies the value.
    Otherwise a default is used (an empty vector, ``DType_FP32`` or 0).

    :param tosaOpName: lower-case TOSA operator name.
    :param allSerialLibAtts: dict mapping Serialization library attribute
        type names to lists of attribute dicts (``name``/``dType``/``SV``).
    :param tosaArgs: the operator's argument dicts (``name``/``type``/``shape``).
    :returns: a list of attribute dicts; empty when the operator has no
        Serialization library attributes.
    """
    serLibOpType = getSerializeOpType(tosaOpName)
    if serLibOpType not in allSerialLibAtts:
        # Always return a list so callers see a single, consistent type.
        return []
    # Deep copy so the edits below do not modify the shared table.
    serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
    tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
    serTosaTypeMap = {"ResizeMode": "tosa_mode"}
    # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml
    if tosaOpName == "reshape" and serLibOpAtts:
        serLibOpAtts[0]["name"] = "shape"
    for att in serLibOpAtts:
        attName = att["name"]
        attType = att["dType"]
        if attType in serTosaTypeMap:
            # Translate TOSA data types to Serialization library data types
            init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[attType]}(client_{attName});"
        elif tosaOpName == "avg_pool2d" and attName == "accum_dtype":
            # avg_pool2d derives accum_dtype from the client's acc_size argument
            init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);"
            att["dType"] = "tosa::DType"
        elif attName in tosaArgsDict:
            # Matching function parameter: vectors need building; scalars are used directly.
            if att["SV"] == "V":
                init = _vectorAttInit(attName, attType, tosaArgsDict[attName])
            else:
                init = ""
        # No matching function parameter: fall back to a default value.
        elif att["SV"] == "V":
            init = f"std::vector<int32_t> {attName};"
        elif attType == "DType":
            att["dType"] = "tosa::DType"
            init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;"
        else:
            init = f"const {attType} {attName} = 0;"
        att["init"] = init
    return serLibOpAtts
Grant Watson64285a12022-11-16 15:32:39 +0000159
160
def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
    """
    Reconcile an operator's TOSA arguments with its Serialization attributes.

    Modifies ``tosaArgs`` in place, and sets ``"init"`` on newly added
    entries of ``serialLibAtts``:

    * A TOSA argument whose type is a generic TOSA argument type takes the
      ``dType`` of the Serialization attribute with the same name.
    * If no such attribute exists, the argument is deleted, together with
      its ``<name>_len`` argument when present.
    * Each Serialization attribute that has no matching TOSA argument (other
      than ``tosa::DType`` attributes) is inserted before the last argument.
      Vector attributes also get a preceding ``<name>_len`` argument.
    """
    tosaArgTypes = getTosaArgTypes(tosaXml)
    serAttsDict = {att["name"]: att for att in serialLibAtts}
    tosaArgsNames = [arg["name"] for arg in tosaArgs]
    # A set, because the same index can be selected twice (e.g. a "<name>_len"
    # argument that is itself unresolvable). Deleting it twice would remove
    # an unrelated argument.
    delTosaArgs = set()
    # Replace TOSA argument data types with their matching Serialization attribute data types.
    for index, tosaArg in enumerate(tosaArgs):
        if tosaArg["type"] not in tosaArgTypes:
            continue
        if tosaArg["name"] in serAttsDict:
            tosaArg["type"] = serAttsDict[tosaArg["name"]]["dType"]
        else:
            # The argument's data type can't be determined; drop it and its length argument.
            delTosaArgs.add(index)
            lenArgName = f"{tosaArg['name']}_len"
            if lenArgName in tosaArgsNames:
                delTosaArgs.add(tosaArgsNames.index(lenArgName))
    # Delete from the highest index down so the remaining indices stay valid.
    for index in sorted(delTosaArgs, reverse=True):
        del tosaArgs[index]
    # Add Serialization attributes that have no matching TOSA argument.
    tosaArgNames = {arg["name"] for arg in tosaArgs}
    for serAtt in serialLibAtts:
        attName = serAtt["name"]
        attType = serAtt["dType"]
        if attName in tosaArgNames or attType == "tosa::DType":
            continue
        if serAtt["SV"] == "V":
            # For vector data types, insert a matching length argument
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{attName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            init = f"const std::vector<{attType}> {attName}(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
            shape = "[]"
        else:
            init = ""
            shape = ""
        serAtt["init"] = init
        # Insert new argument
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": attName,
                "type": attType,
                "shape": shape,
                "category": "",
            },
        )
220
221
def getOperators(tosaXml):
    """
    Build the operator descriptions used to render the API templates.

    Every ``<operator>`` in tosa.xml is included except those listed in
    ``ignoreOps``. Each operator becomes a dict with the keys ``name``,
    ``serializeAttType``, ``arguments``, ``serialLibAtts``, ``inputs``
    (input argument names that are not Serialization attributes) and
    ``outputs``.
    """
    ignoreOps = {
        "while_loop",
        "cond_if",
        "const",
        "custom",
        "fft2d",
        "rfft2d",
        "variable",
        "variable_read",
        "variable_write",
    }
    opsXml = tosaXml.getElementsByTagName("operator")
    allSerialLibAtts = getSerialLibAtts()
    operators = []
    for opXml in opsXml:
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in ignoreOps:
            continue
        serializeAttType = getSerializeOpType(opName)
        tosaArgs = getTosaArgs(opXml)
        serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)
        # Operators without a Serialization attribute type but with an "axis"
        # argument use the generic Axis attribute.
        hasAxisArg = any(arg["name"] == "axis" for arg in tosaArgs)
        if serializeAttType == "None" and hasAxisArg:
            serializeAttType = "Axis"
            serialLibAtts = [
                {
                    "name": "axis",
                    "dType": "int32_t",
                    "SV": "S",
                    "init": "",
                }
            ]
        updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)
        attNames = {att["name"] for att in serialLibAtts}
        inputNames = [
            arg["name"]
            for arg in tosaArgs
            if arg["category"] == "input" and arg["name"] not in attNames
        ]
        outputNames = [arg["name"] for arg in tosaArgs if arg["category"] == "output"]
        operators.append(
            {
                "name": opName,
                "serializeAttType": serializeAttType,
                "arguments": tosaArgs,
                "serialLibAtts": serialLibAtts,
                "inputs": inputNames,
                "outputs": outputNames,
            }
        )
    return operators
274
275
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    :param opXml: the ``<operator>`` DOM element from tosa.xml.
    :param tosaXml: the tosa.xml document used to look up TOSA argument
        types. Defaults to the document that owns ``opXml``.
    :returns: a list of argument dicts with ``name``, ``type``, ``shape`` and
        ``category`` keys. An array argument whose size is not a literal
        number gets shape ``"[]"`` and is preceded by an extra int32_t
        ``<name>_len`` argument.
    """
    if tosaXml is None:
        # Use the owning document instead of a module-level global, which
        # exists only when this file is run as a script.
        tosaXml = opXml.ownerDocument
    arguments = []
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    tensorElemTypeMap = {
        "resize_mode_t": "tosa_mode_t",
        "acc_size_t": "tosa_acc_size_t",
    }
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        tensorElemType = xmlArg.getAttribute("tensor-element-type")
        if tensorElemType in tensorElemTypeMap:
            argType = tensorElemTypeMap[tensorElemType]
        else:
            argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # FullyConnected workaround: weight/bias attributes are passed as inputs
        if argName in ("weight", "bias") and argCategory == "attribute":
            argCategory = "input"
        # Update argument type: strip a trailing pointer marker
        if argType.endswith("*"):
            argType = argType[:-1]
        if argCategory in ("input", "output") and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        argType = tosaTypeMap.get(argType, argType)
        if argShape.startswith("["):
            if not argShape[1:-1].isnumeric():
                # Add a length argument for arrays with unknown compile-time size
                argShape = "[]"
                arguments.append(
                    {
                        "name": f"{argName}_len",
                        "type": "int32_t",
                        "shape": "",
                        "category": "",
                    }
                )
        else:
            # Non-array shapes are not needed by the generated API
            argShape = ""
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
331
332
def clangFormat(filename):
    """
    Format ``filename`` in place with clang-format, discarding its stdout.

    Raises subprocess.CalledProcessError if clang-format exits non-zero.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename],
        stdout=subprocess.DEVNULL,
    )
337
338
def getSerialLibAtts():
    """
    Parse attribute.def and return the Serialization library attributes.

    The parser expects entries laid out as::

        DEF_ATTRIBUTE(<Name>, ...
                      <dType>, <S|V>, <name>,
                      ...
                      <dType>, <S|V>, <name>)

    Everything before the first line starting with ``DEF_ATTRIBUTE(`` is
    ignored.

    :returns: a dict mapping each attribute type name (e.g. ``"Pool"``) to a
        list of ``{"name", "dType", "SV"}`` dicts, one per field, in file order.
    """
    marker = "DEF_ATTRIBUTE("
    serialLibAtts = {}
    attr_def = (
        getBasePath() / "thirdparty" / "serialization_lib" / "include" / "attribute.def"
    )
    with open(attr_def, encoding="utf-8") as file:
        preamble = True
        inAtt = False
        opName = ""
        args = []
        for rawLine in file:
            # Skip the file header up to the first DEF_ATTRIBUTE entry.
            if preamble and not rawLine.startswith(marker):
                continue
            preamble = False
            line = rawLine.strip()
            if not inAtt:
                if marker in line:
                    opName = line[len(marker) : line.find(",")]
                    inAtt = True
                continue
            # Blank or comment lines inside an entry hold no field and would
            # otherwise raise IndexError or yield a bogus field below.
            if not line or line.startswith("//"):
                continue
            vals = line.split(",")
            argName = vals[2].strip()
            if ")" in argName:
                argName = argName[:-1]
            args.append(
                {
                    "name": argName,
                    "dType": vals[0].strip(),
                    "SV": vals[1].strip(),
                }
            )
            # A closing parenthesis ends the current entry.
            if ")" in line:
                serialLibAtts[opName] = args
                opName = ""
                args = []
                inAtt = False
    return serialLibAtts
Grant Watson64285a12022-11-16 15:32:39 +0000380
381
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render ``template`` to ``outfile`` (UTF-8), then clang-format the result.

    The template receives ``dataTypes`` and ``operators`` as variables.
    ``environment`` is accepted but not used here.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as handle:
        handle.write(rendered)
    print(f"Created {outfile}")
    clangFormat(outfile)
389
390
def generate(environment, dataTypes, operators, base_path):
    """
    Render the operator API header and source into the reference model tree.

    Writes include/operators.h from operators_h.j2 first, then
    src/operators.cc from operators_cc.j2, both under
    ``base_path / "reference_model"``.
    """
    targets = (
        ("operators_h.j2", "reference_model/include/operators.h"),
        ("operators_cc.j2", "reference_model/src/operators.cc"),
    )
    for templateName, relativePath in targets:
        template = environment.get_template(templateName)
        renderTemplate(
            environment, dataTypes, operators, template, base_path / relativePath
        )
401
402
if __name__ == "__main__":
    # Every path is anchored to this script's location, not the working directory.
    scriptDir = Path(__file__).resolve().parent
    base_path = getBasePath()
    templateLoader = FileSystemLoader(scriptDir / "templates")
    environment = Environment(loader=templateLoader)
    tosaXmlPath = base_path / "thirdparty" / "specification" / "tosa.xml"
    tosaXml = minidom.parse(str(tosaXmlPath))
    dataTypes = getTosaDataTypes(tosaXml)
    operators = getOperators(tosaXml)
    generate(environment, dataTypes, operators, base_path)