Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 1 | """Generate extended reference model API with eager operator execution entrypoints""" |
Grant Watson | e70d931 | 2023-08-28 16:34:28 +0100 | [diff] [blame] | 2 | # Copyright (c) 2021-2023, ARM Limited. |
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | import copy |
| 5 | import os |
| 6 | import subprocess |
Eric Kunze | 99f8f9f | 2023-09-07 01:36:07 +0000 | [diff] [blame] | 7 | from pathlib import Path |
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 8 | from xml.dom import minidom |
| 9 | |
| 10 | from jinja2 import Environment |
| 11 | from jinja2 import FileSystemLoader |
| 12 | |
# Note: this script lives in scripts/operator_api/; all paths are resolved relative to this file, so it can be run from any directory
| 14 | |
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 15 | |
def getBasePath():
    """Return the repository root: the directory three levels above this script's resolved path."""
    location = Path(__file__).resolve()
    # script file -> operator_api/ -> scripts/ -> repository root
    for _ in range(3):
        location = location.parent
    return location
| 18 | |
| 19 | |
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names.

    The set holds a fixed group of generic type placeholders plus the ``name``
    attribute of every ``<type>`` element found in tosa.xml, with
    ``TABLE_SIZE`` excluded.

    Args:
        tosaXml: Parsed tosa.xml (an ``xml.dom.minidom`` Document or Element).

    Returns:
        set of str.
    """
    argTypes = {
        "tensor_t",
        "in_t",
        "out_t",
        "mul_t",
        "weight_t",
        "in_out_t",
        "tensor_list_t",
    }
    argTypes.update(
        argTypeXml.getAttribute("name")
        for argTypeXml in tosaXml.getElementsByTagName("type")
    )
    # TABLE_SIZE is declared as a <type> but is excluded here (presumably it is a
    # size constant rather than an argument type). discard() rather than remove()
    # so a tosa.xml without it does not raise KeyError.
    argTypes.discard("TABLE_SIZE")
    return argTypes
| 38 | |
| 39 | |
def getTosaDataTypes(tosaXml):
    """
    Return the sorted list of ``tosa_datatype_<type>`` names used in tosa.xml.

    Every ``<typesupport>`` element is inspected for each known argument type
    attribute; each non-empty attribute value contributes one entry.
    """
    knownArgTypes = getTosaArgTypes(tosaXml)
    found = set()
    for support in tosaXml.getElementsByTagName("typesupport"):
        values = (support.getAttribute(name) for name in knownArgTypes)
        found.update(f"tosa_datatype_{value}" for value in values if value)
    return sorted(found)
| 53 | |
| 54 | |
def getSerializeOpType(tosaOpName):
    """
    Returns the Serialization library operator that matches the TOSA operator specified.

    Args:
        tosaOpName: Lower-case TOSA operator name, e.g. "conv2d".

    Returns:
        The Serialization library attribute type name (e.g. "Conv"), or the
        string "None" when the operator has no dedicated attribute type.
    """
    # Named opTypeMap (not ``map``) so the builtin is not shadowed.
    opTypeMap = {
        "avg_pool2d": "Pool",
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "fully_connected": "FullyConnected",
        "matmul": "MatMul",
        "max_pool2d": "Pool",
        "transpose_conv2d": "Conv",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    # The string "None" (not the None object) is what callers compare against.
    return opTypeMap.get(tosaOpName, "None")
| 87 | |
| 88 | |
def _vectorAttInit(attName, attType, tosaArg):
    """Build the C++ snippet that fills a std::vector attribute from its matching client argument."""
    if tosaArg["type"] == "tosa_tensor_t":
        # Tensor argument: reinterpret the client buffer as an array of attType.
        return (
            f"std::vector<{attType}> {attName};"
            f"size_t {attName}_size = client_{attName}.size / sizeof({attType});"
            f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);"
            f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);"
        )
    shape = tosaArg["shape"]
    if shape == "[]":
        # Array of unknown size: its length comes from the client_<name>_len argument.
        end = f"&client_{attName}[0] + client_{attName}_len"
    else:
        # Fixed-size array: the shape text (e.g. "[2]") indexes one past the last element.
        end = f"&client_{attName}{shape}"
    return f"const std::vector<{attType}> {attName}(&client_{attName}[0], {end});"


def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
    """
    Returns the attributes required by the Serialization library for the TOSA operator specified.

    Each returned attribute dict is a deep copy of the entry in ``allSerialLibAtts``
    and gains an "init" key: generated C++ code that initializes a Serialization
    library attribute. If a matching TOSA argument exists, that value is used for
    initialization, otherwise a default value e.g. 0 is used.

    Args:
        tosaOpName: Lower-case TOSA operator name, e.g. "conv2d".
        allSerialLibAtts: Mapping from Serialization library attribute type name to a
            list of {"name", "dType", "SV"} dicts (as produced by getSerialLibAtts()).
        tosaArgs: List of TOSA argument dicts (as produced by getTosaArgs()).

    Returns:
        The list of attribute dicts for the operator, or an empty dict when the
        operator has no Serialization library attribute type.
    """
    serLibOpType = getSerializeOpType(tosaOpName)
    if serLibOpType not in allSerialLibAtts:
        return {}
    # Deep copy: "name", "dType" and "init" are modified below, and the shared
    # table is reused for every operator.
    serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
    tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
    serTosaTypeMap = {"ResizeMode": "tosa_mode"}
    # For reshape operator, change 'new_shape' to 'shape' to match tosa.xml.
    # Matched by name rather than assuming it is the first attribute.
    if tosaOpName == "reshape":
        for att in serLibOpAtts:
            if att["name"] == "new_shape":
                att["name"] = "shape"
    for att in serLibOpAtts:
        attName = att["name"]
        attType = att["dType"]
        if attType in serTosaTypeMap:
            # Translate TOSA data types to Serialization library data types for initialization
            init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[attType]}(client_{attName});"
        elif tosaOpName == "avg_pool2d" and attName == "accum_dtype":
            init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);"
            att["dType"] = "tosa::DType"
        elif attName in tosaArgsDict:
            # Initialize Serialization library attributes to their matching function parameter.
            # Scalars get no init (presumably the same-named parameter is used directly).
            if att["SV"] == "V":
                init = _vectorAttInit(attName, attType, tosaArgsDict[attName])
            else:
                init = ""
        elif att["SV"] == "V":
            # Serialization library vector attribute with no matching function parameter
            init = f"std::vector<int32_t> {attName};"
        elif attType == "DType":
            att["dType"] = "tosa::DType"
            init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;"
        else:
            init = f"const {attType} {attName} = 0;"
        att["init"] = init
    return serLibOpAtts
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 159 | |
| 160 | |
def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
    """
    Replace TOSA argument data types with their matching Serialization attribute data types.
    Delete TOSA arguments where the type couldn't be determined.
    Add Serialization attributes that have no matching TOSA argument.

    ``tosaArgs`` is modified in place; added arguments are inserted just before
    its last element. Each Serialization attribute added this way also has its
    "init" key set.

    Args:
        tosaArgs: List of argument dicts with keys "name", "type", "shape", "category".
        serialLibAtts: Attribute dicts (keys "name", "dType", "SV", ...) for the operator.
        tosaXml: Parsed tosa.xml, used to look up the TOSA argument type names.
    """
    tosaArgTypes = getTosaArgTypes(tosaXml)
    serAttsDict = {att["name"]: att for att in serialLibAtts}
    tosaArgsNames = [arg["name"] for arg in tosaArgs]
    # A set, so an index flagged twice is deleted only once; deleting it twice
    # would remove an unrelated neighbouring argument.
    delTosaArgs = set()
    # Replace TOSA argument data types with their matching Serialization attribute data types.
    for index, tosaArg in enumerate(tosaArgs):
        if tosaArg["type"] not in tosaArgTypes:
            continue
        if tosaArg["name"] in serAttsDict:
            tosaArg["type"] = serAttsDict[tosaArg["name"]]["dType"]
        else:
            # Delete TOSA argument whose data type can't be determined...
            delTosaArgs.add(index)
            # ...and its corresponding length argument if one exists
            lenArgName = f"{tosaArg['name']}_len"
            if lenArgName in tosaArgsNames:
                delTosaArgs.add(tosaArgsNames.index(lenArgName))
    # Delete from the back so the remaining indices stay valid.
    for index in sorted(delTosaArgs, reverse=True):
        del tosaArgs[index]
    # Add Serialization attributes that have no matching TOSA argument
    # (attributes whose dType is already "tosa::DType" are not added).
    tosaArgNames = [arg["name"] for arg in tosaArgs]
    for serAtt in serialLibAtts:
        attName = serAtt["name"]
        attType = serAtt["dType"]
        if attName in tosaArgNames or attType == "tosa::DType":
            continue
        if serAtt["SV"] == "V":
            # For vector data types, insert a matching length argument
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{attName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            init = f"const std::vector<{attType}> {attName}(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
            shape = "[]"
        else:
            init = ""
            shape = ""
        serAtt["init"] = init
        # Insert new argument
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": attName,
                "type": attType,
                "shape": shape,
                "category": "",
            },
        )
| 220 | |
| 221 | |
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each operator is a dict with keys "name", "serializeAttType", "arguments",
    "serialLibAtts", "inputs" and "outputs". Operators listed in ``ignoreOps``
    are skipped.
    """
    ignoreOps = {
        "while_loop",
        "cond_if",
        "const",
        "custom",
        "fft2d",
        "rfft2d",
        "variable",
        "variable_read",
        "variable_write",
    }
    allSerialLibAtts = getSerialLibAtts()
    operators = []
    for opXml in tosaXml.getElementsByTagName("operator"):
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in ignoreOps:
            continue
        operator = {"name": opName, "serializeAttType": getSerializeOpType(opName)}
        tosaArgs = getTosaArgs(opXml)
        serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)
        # Operators with no dedicated Serialization attribute type but an "axis"
        # argument are mapped onto the generic Axis attribute.
        hasAxis = any(arg["name"] == "axis" for arg in tosaArgs)
        if hasAxis and operator["serializeAttType"] == "None":
            operator["serializeAttType"] = "Axis"
            serialLibAtts = [
                {"name": "axis", "dType": "int32_t", "SV": "S", "init": ""}
            ]
        updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)
        operator["arguments"] = tosaArgs
        operator["serialLibAtts"] = serialLibAtts
        attNames = {att["name"] for att in serialLibAtts}
        operator["inputs"] = [
            a["name"]
            for a in tosaArgs
            if a["category"] == "input" and a["name"] not in attNames
        ]
        operator["outputs"] = [a["name"] for a in tosaArgs if a["category"] == "output"]
        operators.append(operator)
    return operators
| 274 | |
| 275 | |
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    Args:
        opXml: The ``<operator>`` element from tosa.xml.
        tosaXml: Parsed tosa.xml, used to look up the TOSA argument type names.
            Defaults to the document that owns ``opXml``. Previously this function
            read an undeclared global ``tosaXml`` that only exists when the file is
            run as a script.

    Returns:
        list of dicts with keys "name", "type", "shape" and "category". An array
        argument whose size is not a literal number gets shape "[]" and is preceded
        by an extra int32_t "<name>_len" argument.
    """
    if tosaXml is None:
        tosaXml = opXml.ownerDocument
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    tensorElemTypeMap = {
        "resize_mode_t": "tosa_mode_t",
        "acc_size_t": "tosa_acc_size_t",
    }
    arguments = []
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        tensorElemType = xmlArg.getAttribute("tensor-element-type")
        if tensorElemType in tensorElemTypeMap:
            argType = tensorElemTypeMap[tensorElemType]
        else:
            argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # FullyConnected workaround: weight/bias attributes are treated as inputs
        if argName in ("weight", "bias") and argCategory == "attribute":
            argCategory = "input"
        # Update argument type: drop a trailing "*", collapse tensors, map names
        if argType.endswith("*"):
            argType = argType[:-1]
        if argCategory in ("input", "output") and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        argType = tosaTypeMap.get(argType, argType)
        if argShape.startswith("["):
            # Add a length argument for arrays with unknown compile-time size
            if not argShape[1:-1].isnumeric():
                argShape = "[]"
                arguments.append(
                    {
                        "name": f"{argName}_len",
                        "type": "int32_t",
                        "shape": "",
                        "category": "",
                    }
                )
        else:
            argShape = ""
        # Append argument
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
| 331 | |
| 332 | |
def clangFormat(filename):
    """Reformat ``filename`` in place with clang-format, discarding its stdout.

    Raises subprocess.CalledProcessError if clang-format exits non-zero.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename], stdout=subprocess.DEVNULL
    )
| 337 | |
| 338 | |
def getSerialLibAtts(attr_def=None):
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are the arguments required by each Serialization library operator.

    Args:
        attr_def: Optional path to an attribute.def file. Defaults to
            thirdparty/serialization_lib/include/attribute.def under the repository root.

    Returns:
        dict mapping each DEF_ATTRIBUTE name (e.g. "Pool") to a list of
        {"name", "dType", "SV"} dicts, in declaration order.
    """
    if attr_def is None:
        attr_def = (
            getBasePath() / "thirdparty" / "serialization_lib" / "include" / "attribute.def"
        )
    marker = "DEF_ATTRIBUTE("
    serialLibAtts = {}
    preamble = True
    inAtt = False
    opName = ""
    args = []
    with open(attr_def, encoding="utf-8") as file:
        for rawLine in file:
            # Skip everything before the first line that *starts* with DEF_ATTRIBUTE(
            # (indented occurrences, e.g. in a usage comment, stay in the preamble).
            if preamble:
                if not rawLine.startswith(marker):
                    continue
                preamble = False
            line = rawLine.strip()
            if not inAtt:
                if marker in line:
                    opName = line[len(marker) : line.find(",")]
                    inAtt = True
                continue
            # Blank or comment lines inside an attribute body carry no argument;
            # previously they raised IndexError.
            if not line or line.startswith("//"):
                continue
            # Argument lines look like "<dType>, <S|V>, <name>," with ")" closing the attribute
            vals = line.split(",")
            argName = vals[2].strip()
            if ")" in argName:
                argName = argName[:-1]
            args.append(
                {
                    "name": argName,
                    "dType": vals[0].strip(),
                    "SV": vals[1].strip(),
                }
            )
            if ")" in line:
                serialLibAtts[opName] = args
                opName = ""
                args = []
                inAtt = False
    return serialLibAtts
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 380 | |
| 381 | |
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render ``template`` with ``dataTypes`` and ``operators``, write the result to
    ``outfile`` as UTF-8, report it on stdout, then run clang-format on it.

    ``environment`` is accepted for signature compatibility but not used here.
    """
    rendered = template.render(operators=operators, dataTypes=dataTypes)
    with open(outfile, mode="w", encoding="utf-8") as handle:
        handle.write(rendered)
    print(f"Created {outfile}")
    clangFormat(outfile)
| 389 | |
| 390 | |
def generate(environment, dataTypes, operators, base_path):
    """
    Render the operator API header and source from their Jinja2 templates into
    ``base_path``/reference_model (include/operators.h, then src/operators.cc).
    """
    targets = (
        ("operators_h.j2", "reference_model/include/operators.h"),
        ("operators_cc.j2", "reference_model/src/operators.cc"),
    )
    for templateName, relativePath in targets:
        template = environment.get_template(templateName)
        renderTemplate(
            environment, dataTypes, operators, template, base_path / relativePath
        )
| 401 | |
| 402 | |
if __name__ == "__main__":
    # Script entry point: locate the repository from this file, parse the TOSA
    # specification and regenerate the operator API sources.
    base_path = getBasePath()
    templateDir = Path(__file__).resolve().parent / "templates"
    environment = Environment(loader=FileSystemLoader(templateDir))
    specPath = base_path / "thirdparty/specification/tosa.xml"
    # Kept as a module-level name: other code in this file may read ``tosaXml`` as a global.
    tosaXml = minidom.parse(str(specPath))
    dataTypes = getTosaDataTypes(tosaXml)
    operators = getOperators(tosaXml)
    generate(environment, dataTypes, operators, base_path)