blob: f1cb6e03be76574d312e9ac27cf1d461ba8537ea [file] [log] [blame]
Grant Watson64285a12022-11-16 15:32:39 +00001"""Generate extended reference model API with eager operator execution entrypoints"""
Grant Watsone70d9312023-08-28 16:34:28 +01002# Copyright (c) 2021-2023, ARM Limited.
Grant Watson64285a12022-11-16 15:32:39 +00003# SPDX-License-Identifier: Apache-2.0
4import copy
5import os
6import subprocess
7from xml.dom import minidom
8
9from jinja2 import Environment
10from jinja2 import FileSystemLoader
11
James Wardd34b3fc2023-01-18 14:51:25 +000012# Note: main script designed to be run from the scripts/operator_api/ directory
13
Grant Watson64285a12022-11-16 15:32:39 +000014
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names derived from tosa.xml.

    The result starts from a fixed set of generic tensor type names and adds
    the "name" attribute of every <type> element found in tosaXml.
    "TABLE_SIZE" is never part of the result.

    tosaXml: a DOM document/node providing getElementsByTagName().
    Returns: a new set of type name strings.
    """
    argTypes = {
        "tensor_t",
        "in_t",
        "out_t",
        "mul_t",
        "weight_t",
        "in_out_t",
        "tensor_list_t",
    }
    for argTypeXml in tosaXml.getElementsByTagName("type"):
        argTypes.add(argTypeXml.getAttribute("name"))
    # discard() instead of remove(): a document that does not declare
    # TABLE_SIZE must not abort the generator with a KeyError.
    argTypes.discard("TABLE_SIZE")
    return argTypes
33
34
def getTosaDataTypes(tosaXml):
    """
    Returns a sorted list of the TOSA data types from tosa.xml.

    Every <typesupport> element is queried for each known argument type
    name; each non-empty attribute value found is reported once, prefixed
    with "tosa_datatype_".
    """
    argTypes = getTosaArgTypes(tosaXml)
    supportEntries = tosaXml.getElementsByTagName("typesupport")
    # getAttribute() yields "" for an attribute that is not present, so the
    # empty string is filtered out below.
    attributeValues = (
        entry.getAttribute(argType)
        for entry in supportEntries
        for argType in argTypes
    )
    return sorted(
        {f"tosa_datatype_{value}" for value in attributeValues if value != ""}
    )
48
49
# TOSA operator name -> Serialization library attribute type.
# Built once at import time instead of on every lookup.
_SERIALIZE_OP_TYPES = {
    "avg_pool2d": "Pool",
    "conv2d": "Conv",
    "conv3d": "Conv",
    "depthwise_conv2d": "Conv",
    "fully_connected": "FullyConnected",
    "matmul": "MatMul",
    "max_pool2d": "Pool",
    "transpose_conv2d": "Conv",
    "clamp": "Clamp",
    "arithmetic_right_shift": "ArithmeticRightShift",
    "mul": "Mul",
    "table": "Table",
    "negate": "Negate",
    "pad": "Pad",
    "reshape": "Reshape",
    "slice": "Slice",
    "tile": "Tile",
    "transpose": "Transpose",
    "resize": "Resize",
    "rescale": "Rescale",
    "cond_if": "CondIf",
    "while_loop": "WhileLoop",
}


def getSerializeOpType(tosaOpName):
    """
    Returns the Serialization library operator that matches the TOSA operator specified.

    tosaOpName: lower-case TOSA operator name, e.g. "conv2d".
    Returns: the matching Serialization attribute type name, or the string
    "None" (not the None object) when the operator has no entry in the table.
    """
    return _SERIALIZE_OP_TYPES.get(tosaOpName, "None")
82
83
def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
    """
    Returns the Serialization library arguments for the TOSA operator specified.

    Each returned argument gains an "init" entry holding the code fragment
    that initializes it: from the matching TOSA function parameter when one
    exists, otherwise from a default value (e.g. 0).

    Returns an empty dict when the operator's Serialization type has no entry
    in allSerializeArgs; otherwise a list of argument dicts.
    """
    serOpType = getSerializeOpType(tosaOpName)
    if serOpType not in allSerializeArgs:
        return {}

    # Work on a private copy: "dType" and "init" are written below and the
    # shared table in allSerializeArgs must stay untouched.
    opArgs = copy.deepcopy(allSerializeArgs[serOpType])
    paramsByName = {param["name"]: param for param in tosaArgs}
    translatedTypes = {"ResizeMode": "tosa_mode"}

    for serArg in opArgs:
        name = serArg["name"]
        if serArg["dType"] in translatedTypes:
            # TOSA data type that must be translated to its Serialization type
            init = f" = translate_client_{translatedTypes[serArg['dType']]}(client_{name})"
        elif name in paramsByName:
            # A function parameter with the same name exists: initialize from it
            if serArg["SV"] != "V":
                init = f" = client_{name}"
            else:
                shape = paramsByName[name]["shape"]
                if shape == "[]":
                    init = f"(&client_{name}[0], &client_{name}[0] + client_{name}_len)"
                else:
                    init = f"(&client_{name}[0], &client_{name}{shape})"
        elif serArg["SV"] == "V":
            # Vector with no matching function parameter: no initializer
            init = ""
        elif serArg["dType"] == "DType":
            # Scalar data-type argument with no matching function parameter
            serArg["dType"] = "tosa::DType"
            init = " = tosa::DType::DType_FP32"
        else:
            # Any other scalar with no matching function parameter
            init = " = 0"
        serArg["init"] = init
    return opArgs
125
126
def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
    """
    Reconcile the TOSA argument list with the Serialization arguments, in place.

    - A TOSA argument whose type is a TOSA argument type takes the data type
      of the Serialization argument with the same name.
    - If no such Serialization argument exists, the TOSA argument is deleted,
      together with its "<name>_len" companion when there is one.
    - Serialization arguments with no matching TOSA argument (other than those
      typed "tosa::DType") are inserted just before the last TOSA argument,
      preceded by a "<name>_len" argument when they are vectors. Their "init"
      entry in serializeArgs is set accordingly.

    Nothing is returned; tosaArgs and the dicts in serializeArgs are mutated.
    """
    knownArgTypes = getTosaArgTypes(tosaXml)
    serByName = {serArg["name"]: serArg for serArg in serializeArgs}
    originalNames = [tosaArg["name"] for tosaArg in tosaArgs]

    # Pass 1: retype arguments, collecting the positions of those to delete.
    doomed = []
    for tosaArg in tosaArgs:
        if tosaArg["type"] not in knownArgTypes:
            continue
        if tosaArg["name"] in serByName:
            tosaArg["type"] = serByName[tosaArg["name"]]["dType"]
            continue
        # Data type can't be determined: schedule the argument for deletion
        doomed.append(originalNames.index(tosaArg["name"]))
        # ... along with its length argument, if one exists
        lengthName = f"{tosaArg['name']}_len"
        if lengthName in originalNames:
            doomed.append(originalNames.index(lengthName))

    # Delete from the back so earlier positions remain valid.
    doomed.sort(reverse=True)
    for position in doomed:
        del tosaArgs[position]

    # Pass 2: add Serialization arguments that have no matching TOSA argument.
    # The name snapshot is taken once, before any insertion.
    remainingNames = [tosaArg["name"] for tosaArg in tosaArgs]
    for serArg in serializeArgs:
        if serArg["name"] in remainingNames or serArg["dType"] == "tosa::DType":
            continue
        serArgName = serArg["name"]
        if serArg["SV"] != "V":
            serArg["init"] = f" = client_{serArg['name']}"
            shape = ""
        else:
            # Vector data types get a matching length argument first
            lengthArg = {
                "name": f"{serArgName}_len",
                "type": "int32_t",
                "shape": "",
                "category": "",
            }
            tosaArgs.insert(len(tosaArgs) - 1, lengthArg)
            serArg["init"] = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)"
            shape = "[]"
        # Insert the new argument itself
        newArg = {
            "name": serArgName,
            "type": serArg["dType"],
            "shape": shape,
            "category": "",
        }
        tosaArgs.insert(len(tosaArgs) - 1, newArg)
186
187
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each entry is a dict with the keys "name", "serializeAttType",
    "arguments", "serializeArgs", "inputs" and "outputs". Operators named in
    the skip list below are left out.
    """
    skippedOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
    opNodes = tosaXml.getElementsByTagName("operator")
    allSerializeArgs = getSerializeArgs()

    operators = []
    for opNode in opNodes:
        opName = opNode.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in skippedOps:
            continue

        attType = getSerializeOpType(opName)
        tosaArgs = getTosaArgs(opNode)
        serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)

        # Handle "axis" arguments: an operator with no Serialization type of
        # its own but an "axis" argument is serialized as an Axis attribute.
        hasAxis = any(arg["name"] == "axis" for arg in tosaArgs)
        if attType == "None" and hasAxis:
            attType = "Axis"
            serializeArgs = [
                {
                    "name": "axis",
                    "dType": "int32_t",
                    "SV": "S",
                    "init": "= client_axis",
                }
            ]

        # Must run before inputs/outputs are collected: it edits tosaArgs.
        updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
        operators.append(
            {
                "name": opName,
                "serializeAttType": attType,
                "arguments": tosaArgs,
                "serializeArgs": serializeArgs,
                "inputs": [
                    arg["name"] for arg in tosaArgs if arg["category"] == "input"
                ],
                "outputs": [
                    arg["name"] for arg in tosaArgs if arg["category"] == "output"
                ],
            }
        )
    return operators
226
227
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    opXml: the <operator> DOM element to read <argument> children from.
    tosaXml: the DOM document used to look up the TOSA argument types.
        Defaults to the document that owns opXml, so the function no longer
        depends on a module-level "tosaXml" global being defined.

    Returns a list of dicts with the keys "name", "type", "shape" and
    "category". An array argument whose size is not a numeric literal is
    preceded by an extra "<name>_len" argument of type int32_t.
    """
    if tosaXml is None:
        tosaXml = opXml.ownerDocument
    arguments = []
    argsXml = opXml.getElementsByTagName("argument")
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    for xmlArg in argsXml:
        argName = xmlArg.getAttribute("name").lower()
        if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t":
            argType = "tosa_mode_t"
        else:
            argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # FullyConnected workaround
        if argName in ("weight", "bias") and argCategory == "attribute":
            argCategory = "input"
        # Update argument type: drop a trailing "*" ...
        if argType.endswith("*"):
            argType = argType[:-1]
        # ... turn tensor-typed inputs/outputs into opaque tensor handles ...
        if argCategory in ["input", "output"] and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        # ... and map the remaining TOSA scalar types to their API types.
        argType = tosaTypeMap.get(argType, argType)
        if argShape.startswith("["):
            # Add a length argument for arrays with unknown compile-time size
            if not argShape[1:-1].isnumeric():
                argShape = "[]"
                arguments.append(
                    {
                        "name": f"{argName}_len",
                        "type": "int32_t",
                        "shape": "",
                        "category": "",
                    }
                )
        else:
            # Anything that is not a bracketed array shape is dropped
            argShape = ""
        # Append argument
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
278
279
def clangFormat(filename):
    """
    Reformat filename in place with clang-format, discarding the tool's stdout.

    Raises subprocess.CalledProcessError if clang-format exits with a
    non-zero status.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename], stdout=subprocess.DEVNULL
    )
284
285
def getSerializeArgs(
    attributeDefPath="../../thirdparty/serialization_lib/include/attribute.def",
):
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are the arguments required by each Serialization library operator.

    attributeDefPath: location of attribute.def. The default is relative to
        the scripts/operator_api/ directory.

    Each argument is a dict with the keys "name", "dType" and "SV", read from
    the "<dType>, <SV>, <name>" lines that follow a "DEF_ATTRIBUTE(<op>," line.
    An attribute ends at the first such line containing ")".
    """
    serializeArgs = {}
    with open(attributeDefPath, encoding="utf-8") as file:
        preamble = True
        inAtt = False
        opName = ""
        args = []
        for rawLine in file:
            # Skip everything before the first DEF_ATTRIBUTE at column 0
            if preamble:
                if not rawLine.startswith("DEF_ATTRIBUTE("):
                    continue
                preamble = False
            line = rawLine.strip()
            if not inAtt:
                if "DEF_ATTRIBUTE(" in line:
                    opName = line[len("DEF_ATTRIBUTE(") : line.find(",")]
                    inAtt = True
                continue
            # A blank line inside an attribute carries no argument; without
            # this guard the vals[2] lookup below raises IndexError.
            if not line:
                continue
            vals = line.split(",")
            argName = vals[2].strip()
            if ")" in argName:
                # Last argument of the attribute: drop the closing bracket
                argName = argName[:-1]
            args.append(
                {
                    "name": argName,
                    "dType": vals[0].strip(),
                    "SV": vals[1].strip(),
                }
            )
            if ")" in line:
                serializeArgs[opName] = args
                opName = ""
                args = []
                inAtt = False
    return serializeArgs
323
324
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render template with the given data types and operators, write the result
    to outfile as UTF-8, then run clang-format on the written file.

    The environment argument is accepted but not used here.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as handle:
        handle.write(rendered)
        print(f"Created {outfile}")

    clangFormat(outfile)
332
333
def generate(environment, dataTypes, operators):
    """
    Generate include/operators.h and src/operators.cc of the reference model
    from their templates, in that order.
    """
    targets = (
        ("operators_h.j2", ("include", "operators.h")),
        ("operators_cc.j2", ("src", "operators.cc")),
    )
    for templateName, relativeParts in targets:
        template = environment.get_template(templateName)
        outfile = os.path.join("..", "..", "reference_model", *relativeParts)
        renderTemplate(environment, dataTypes, operators, template, outfile)
344
345
def getSerializeOpTypeMap():
    """
    Utility function for generating the map used in getSerializeOpType()

    Serialization attribute names are converted from CamelCase to snake_case
    and matched as substrings against the TOSA operator names in tosa.xml.
    """
    import re

    snakeNames = sorted(
        (
            re.sub(r"(?<!^)(?=[A-Z])", "_", attName).lower()
            for attName in getSerializeArgs().keys()
        ),
        key=len,
        reverse=True,
    )
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    opNames = [
        opNode.getElementsByTagName("name")[0].firstChild.data.lower()
        for opNode in tosaXml.getElementsByTagName("operator")
    ]
    opTypeMap = {}
    for opName in opNames:
        matches = [snake for snake in snakeNames if snake in opName]
        if matches:
            # Candidates are ordered longest first and the last match is the
            # one kept, i.e. the shortest matching name wins.
            # NOTE(review): confirm this is intended rather than a missing break.
            opTypeMap[opName] = "".join(
                part.title() for part in matches[-1].split("_")
            )
    return opTypeMap
370
371
if __name__ == "__main__":
    # All paths are relative: run this script from scripts/operator_api/.
    # Keep the name tosaXml at module level: other code in this file may
    # refer to it as a global.
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    environment = Environment(loader=FileSystemLoader("templates/"))
    generate(environment, getTosaDataTypes(tosaXml), getOperators(tosaXml))
377 generate(environment, dataTypes, operators)