Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 1 | """Generate extended reference model API with eager operator execution entrypoints""" |
| 2 | # Copyright (c) 2021-2022, ARM Limited. |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | import copy |
| 5 | import os |
| 6 | import subprocess |
| 7 | from xml.dom import minidom |
| 8 | |
| 9 | from jinja2 import Environment |
| 10 | from jinja2 import FileSystemLoader |
| 11 | |
| 12 | |
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names from tosa.xml.

    The result starts from a fixed group of generic type names and is extended
    with the "name" attribute of every <type> element found in tosaXml.
    "TABLE_SIZE" is excluded from the result (presumably because it is a size
    constant rather than an element type - TODO confirm against the spec).

    tosaXml: parsed XML document (or node) providing getElementsByTagName().
    """
    argTypes = {"in_t", "out_t", "mul_t", "weight_t", "in_out_t"}
    for argTypeXml in tosaXml.getElementsByTagName("type"):
        argTypes.add(argTypeXml.getAttribute("name"))
    # discard() rather than remove(): a specification without a TABLE_SIZE
    # entry must not abort the generator with a KeyError
    argTypes.discard("TABLE_SIZE")
    return argTypes
| 23 | |
| 24 | |
def getTosaDataTypes(tosaXml):
    """
    Returns a sorted list of the TOSA data types from tosa.xml.

    Every <typesupport> element is probed for each known argument type name;
    each non-empty attribute value is reported once, prefixed with
    "tosa_datatype_".
    """
    attributeNames = getTosaArgTypes(tosaXml)
    collected = set()
    for supportNode in tosaXml.getElementsByTagName("typesupport"):
        values = (supportNode.getAttribute(name) for name in attributeNames)
        # An empty string means the attribute carries no data type here
        collected.update(f"tosa_datatype_{value}" for value in values if value != "")
    return sorted(collected)
| 38 | |
| 39 | |
# TOSA operator name -> Serialization library attribute name.
# Built once at import time instead of on every lookup.
_SERIALIZE_OP_TYPES = {
    "avg_pool2d": "Pool",
    "conv2d": "Conv",
    "conv3d": "Conv",
    "depthwise_conv2d": "Conv",
    "fully_connected": "FullyConnected",
    "matmul": "MatMul",
    "max_pool2d": "Pool",
    "transpose_conv2d": "Conv",
    "clamp": "Clamp",
    "arithmetic_right_shift": "ArithmeticRightShift",
    "mul": "Mul",
    "table": "Table",
    "negate": "Negate",
    "pad": "Pad",
    "reshape": "Reshape",
    "slice": "Slice",
    "tile": "Tile",
    "transpose": "Transpose",
    "resize": "Resize",
    "rescale": "Rescale",
    "cond_if": "CondIf",
    "while_loop": "WhileLoop",
}


def getSerializeOpType(tosaOpName):
    """
    Returns the Serialization library operator that matches the TOSA operator specified.

    Operators without an entry in the table yield the string "None" (not the
    None object).
    """
    return _SERIALIZE_OP_TYPES.get(tosaOpName, "None")
| 72 | |
| 73 | |
def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
    """
    Returns the arguments required by the Serialization library for the TOSA operator specified.

    Each returned argument is a copy of its entry in allSerializeArgs with an
    extra "init" key holding the generated initialization code. If a matching
    TOSA argument exists, that value is used for initialization, otherwise a
    default value e.g. 0 is used. Operators with no Serialization arguments
    yield an empty dict.
    """
    serOpType = getSerializeOpType(tosaOpName)
    if serOpType not in allSerializeArgs:
        return {}

    # Work on a deep copy: entries are modified below and the table in
    # allSerializeArgs is shared between operators
    serOpArgs = copy.deepcopy(allSerializeArgs[serOpType])
    tosaArgsByName = {tosaArg["name"]: tosaArg for tosaArg in tosaArgs}
    serTosaTypeMap = {"ResizeMode": "tosa_mode"}

    for arg in serOpArgs:
        argName = arg["name"]
        if arg["dType"] in serTosaTypeMap:
            # Types that need translating from the client type to the Serialization type
            init = f" = translate_client_{serTosaTypeMap[arg['dType']]}(client_{argName})"
        elif argName in tosaArgsByName:
            # A function parameter with the same name exists: initialize from it
            if arg["SV"] != "V":
                init = f" = client_{argName}"
            else:
                shape = tosaArgsByName[argName]["shape"]
                if shape == "[]":
                    init = f"(&client_{argName}[0], &client_{argName}[0] + client_{argName}_len)"
                else:
                    init = f"(&client_{argName}[0], &client_{argName}{shape})"
        elif arg["SV"] == "V":
            # No matching function parameter: vectors get no initializer
            init = ""
        elif arg["dType"] == "DType":
            # No matching function parameter: data type arguments default to FP32
            arg["dType"] = "tosa::DType"
            init = " = tosa::DType::DType_FP32"
        else:
            # No matching function parameter: everything else defaults to 0
            init = " = 0"
        arg["init"] = init
    return serOpArgs
| 115 | |
| 116 | |
def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
    """
    Reconcile the TOSA argument list with the Serialization arguments, in place.

    - TOSA arguments with a generic TOSA type take the data type of the
      Serialization argument with the same name.
    - TOSA arguments with a generic TOSA type and no such Serialization
      argument are deleted, together with their "<name>_len" argument if any.
    - Serialization arguments with no matching TOSA argument (other than
      "tosa::DType" ones) are inserted just before the last TOSA argument and
      have their "init" code set.
    """

    def _plainArg(name, typeName, shape):
        # Arguments created here carry no category
        return {"name": name, "type": typeName, "shape": shape, "category": ""}

    genericTypes = getTosaArgTypes(tosaXml)
    serArgsByName = {serArg["name"]: serArg for serArg in serializeArgs}
    originalNames = [tosaArg["name"] for tosaArg in tosaArgs]

    # Pass 1: retype what can be retyped, remember the positions of the rest
    doomedIndices = []
    for tosaArg in tosaArgs:
        if tosaArg["type"] not in genericTypes:
            continue
        if tosaArg["name"] in serArgsByName:
            tosaArg["type"] = serArgsByName[tosaArg["name"]]["dType"]
            continue
        doomedIndices.append(originalNames.index(tosaArg["name"]))
        lenArgName = f"{tosaArg['name']}_len"
        if lenArgName in originalNames:
            doomedIndices.append(originalNames.index(lenArgName))

    # Pass 2: delete from the highest index down so pending indices stay valid
    for index in sorted(doomedIndices, reverse=True):
        del tosaArgs[index]

    # Pass 3: add Serialization arguments that have no TOSA counterpart
    survivingNames = [tosaArg["name"] for tosaArg in tosaArgs]
    for serArg in serializeArgs:
        if serArg["name"] in survivingNames or serArg["dType"] == "tosa::DType":
            continue
        serArgName = serArg["name"]
        if serArg["SV"] == "V":
            # Vectors are passed as pointer + length, so add the length first
            tosaArgs.insert(
                len(tosaArgs) - 1, _plainArg(f"{serArgName}_len", "int32_t", "")
            )
            init = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)"
            shape = "[]"
        else:
            init = f" = client_{serArg['name']}"
            shape = ""
        serArg["init"] = init
        tosaArgs.insert(
            len(tosaArgs) - 1, _plainArg(serArgName, serArg["dType"], shape)
        )
| 176 | |
| 177 | |
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each operator is a dict with its lower-cased name, Serialization attribute
    type, argument list, Serialization arguments and the names of its input
    and output arguments. A fixed set of operators is skipped.
    """
    skippedOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
    opsXml = tosaXml.getElementsByTagName("operator")
    allSerializeArgs = getSerializeArgs()

    def _namesInCategory(args, category):
        return [arg["name"] for arg in args if arg["category"] == category]

    collected = []
    for opXml in opsXml:
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in skippedOps:
            continue
        tosaArgs = getTosaArgs(opXml)
        serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
        # Updates tosaArgs (and the "init" of some serializeArgs) in place
        updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
        collected.append(
            {
                "name": opName,
                "serializeAttType": getSerializeOpType(opName),
                "arguments": tosaArgs,
                "serializeArgs": serializeArgs,
                "inputs": _namesInCategory(tosaArgs, "input"),
                "outputs": _namesInCategory(tosaArgs, "output"),
            }
        )
    return collected
| 204 | |
| 205 | |
def getTosaArgs(opXml):
    """
    Return the arguments required for the TOSA operator specified.

    opXml: the <operator> element whose <argument> children are read.

    Returns a list of dicts with the keys "name", "type", "shape" and
    "category". Arrays whose size is not a numeric literal get the shape "[]"
    and are preceded by an extra "<name>_len" argument of type int32_t.
    """
    # Take the document from the node itself; this function has no tosaXml
    # parameter and must not depend on a module-level variable of that name
    tosaTensorTypes = getTosaArgTypes(opXml.ownerDocument)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    arguments = []
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # Update argument type
        if argType.endswith("*"):
            argType = argType[:-1]
        if argCategory in ["input", "output"] and argType in tosaTensorTypes:
            # Tensor inputs/outputs are passed as tosa_tensor_t and need no shape
            argType = "tosa_tensor_t"
            argShape = ""
        argType = tosaTypeMap.get(argType, argType)
        if not argShape.startswith("["):
            # Only bracketed array shapes are kept
            argShape = ""
        elif not argShape[1:-1].isnumeric():
            # Add a length argument for arrays with unknown compile-time size
            argShape = "[]"
            arguments.append(
                {
                    "name": f"{argName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                }
            )
        # Append argument
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
| 250 | |
| 251 | |
def clangFormat(filename):
    """
    Run "clang-format -i" on filename, discarding the tool's standard output.

    check_call() raises subprocess.CalledProcessError if clang-format exits
    with a non-zero status.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename], stdout=subprocess.DEVNULL
    )
| 256 | |
| 257 | |
def getSerializeArgs():
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are the arguments required by each Serialization library operator.

    Each argument is a dict with the keys "name", "dType" and "SV", read from
    the comma separated fields of one line of a DEF_ATTRIBUTE(...) entry.
    """
    marker = "DEF_ATTRIBUTE("
    parsed = {}
    with open("../../thirdparty/serialization_lib/include/attribute.def") as file:
        pastPreamble = False
        insideEntry = False
        entryName = ""
        entryArgs = []
        for rawLine in file:
            # Skip everything up to the first line that starts an entry
            if not pastPreamble:
                if not rawLine.startswith(marker):
                    continue
                pastPreamble = True
            text = rawLine.strip()
            if insideEntry:
                # Argument line: "<dType>, <SV>, <name>" with "," or ")" after it
                fields = [field.strip() for field in text.split(",")]
                fieldName = fields[2]
                if ")" in fieldName:
                    fieldName = fieldName[:-1]
                entryArgs.append(
                    {"name": fieldName, "dType": fields[0], "SV": fields[1]}
                )
            elif marker in text:
                # Opening line: the entry name runs up to the first comma
                entryName = text[len(marker) : text.find(",")]
                insideEntry = True
            if ")" in text:
                # A closing parenthesis ends the current entry
                parsed[entryName] = entryArgs
                entryName = ""
                entryArgs = []
                insideEntry = False
    return parsed
| 295 | |
| 296 | |
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render template with the given data types and operators, write the result
    to outfile as UTF-8 and run clang-format on the written file.

    environment is accepted for interface compatibility but is not used here.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as handle:
        handle.write(rendered)
        print(f"Created {outfile}")

    # Format only after the file has been closed
    clangFormat(outfile)
| 304 | |
| 305 | |
def generate(environment, dataTypes, operators):
    """
    Render each API template into its file under ../../reference_model.
    """
    # (template name, path below reference_model), processed in this order
    targets = (
        ("operators_h.j2", ("include", "operators.h")),
        ("operators_cc.j2", ("src", "operators.cc")),
    )
    for templateName, relativeParts in targets:
        template = environment.get_template(templateName)
        outfile = os.path.join("..", "..", "reference_model", *relativeParts)
        renderTemplate(environment, dataTypes, operators, template, outfile)
| 316 | |
| 317 | |
def getSerializeOpTypeMap():
    """
    Utility function for generating the map used in getSerializeOpType()

    Serialization attribute names are converted to snake_case and matched as
    substrings of the TOSA operator names from tosa.xml. Candidates are tried
    from longest to shortest and a later match overwrites an earlier one.
    """
    import re

    allSerializeArgs = getSerializeArgs()
    # Insert "_" before every capital letter except the first, then lower-case
    snakeNames = sorted(
        (
            re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
            for name in allSerializeArgs
        ),
        key=len,
        reverse=True,
    )
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    opNames = [
        opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        for opXml in tosaXml.getElementsByTagName("operator")
    ]
    opTypeMap = {}
    for opName in opNames:
        for snakeName in snakeNames:
            if snakeName in opName:
                # Back to the CamelCase spelling used by the Serialization library
                opTypeMap[opName] = "".join(
                    part.title() for part in snakeName.split("_")
                )
    return opTypeMap
| 342 | |
| 343 | |
if __name__ == "__main__":
    # Script entry point: paths are relative to the current working directory.
    environment = Environment(loader=FileSystemLoader("templates/"))
    # NOTE: keep the name `tosaXml` at module scope; code in this module may
    # look it up as a global.
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    generate(environment, getTosaDataTypes(tosaXml), getOperators(tosaXml))