Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 1 | """Generate extended reference model API with eager operator execution entrypoints""" |
Grant Watson | e70d931 | 2023-08-28 16:34:28 +0100 | [diff] [blame] | 2 | # Copyright (c) 2021-2023, ARM Limited. |
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | import copy |
| 5 | import os |
| 6 | import subprocess |
Eric Kunze | 99f8f9f | 2023-09-07 01:36:07 +0000 | [diff] [blame] | 7 | from pathlib import Path |
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 8 | from xml.dom import minidom |
| 9 | |
| 10 | from jinja2 import Environment |
| 11 | from jinja2 import FileSystemLoader |
| 12 | |
# Note: all input/output paths are resolved relative to this script's location (see getBasePath()), so it can be run from any directory
| 14 | |
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 15 | |
def getBasePath():
    """Return the project base directory: three levels above this file's resolved path."""
    location = Path(__file__).resolve()
    for _ in range(3):
        location = location.parent
    return location
| 18 | |
| 19 | |
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names from tosa.xml.

    The set holds a fixed list of generic type names used in operator argument
    declarations, plus the "name" attribute of every <type> element in tosaXml.
    "TABLE_SIZE" is never included.
    """
    argTypes = {
        "tensor_t",
        "in_t",
        "out_t",
        "mul_t",
        "weight_t",
        "in_out_t",
        "tensor_list_t",
    }
    for argTypeXml in tosaXml.getElementsByTagName("type"):
        argTypes.add(argTypeXml.getAttribute("name"))
    # TABLE_SIZE is excluded from the result (presumably a size constant rather than an
    # element type - confirm against tosa.xml). discard() instead of remove() so a spec
    # that does not declare it does not raise KeyError.
    argTypes.discard("TABLE_SIZE")
    return argTypes
| 38 | |
| 39 | |
def getTosaDataTypes(tosaXml):
    """
    Collect every data type named by a <typesupport> element in tosa.xml.

    For each <typesupport> element, the non-empty value of every attribute whose name is
    a TOSA argument type (see getTosaArgTypes) is prefixed with "tosa_datatype_".
    Returns the distinct resulting names as a sorted list.
    """
    argTypes = getTosaArgTypes(tosaXml)
    names = {
        f"tosa_datatype_{value}"
        for support in tosaXml.getElementsByTagName("typesupport")
        for value in map(support.getAttribute, argTypes)
        if value
    }
    return sorted(names)
| 53 | |
| 54 | |
def getSerializeOpType(tosaOpName):
    """
    Returns the Serialization library operator that matches the TOSA operator specified.

    tosaOpName: lower-case TOSA operator name, e.g. "conv2d".
    Returns the Serialization library attribute name (e.g. "Conv"), or the string "None"
    when the operator has no dedicated attribute. A string is returned rather than the
    None object because callers compare the result against "None".
    """
    serializeOpTypes = {
        "avg_pool2d": "Pool",
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "fully_connected": "FullyConnected",
        "fft2d": "FFT",
        "rfft2d": "RFFT",
        "matmul": "MatMul",
        "max_pool2d": "Pool",
        "transpose_conv2d": "TransposeConv",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    return serializeOpTypes.get(tosaOpName, "None")
| 89 | |
| 90 | |
def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
    """
    Returns the attributes required by the Serialization library for the TOSA operator specified.
    Generates code to initialize Serialization library attributes. If a matching TOSA argument exists,
    that value is used for initialization, otherwise a default value e.g. 0 is used.

    tosaOpName: lower-case TOSA operator name.
    allSerialLibAtts: dict from Serialization library operator name to its attribute
        list, as returned by getSerialLibAtts(); it is not modified.
    tosaArgs: the operator's argument dicts, as returned by getTosaArgs().
    Returns a list of attribute dicts, which are deep copies of the library entries. Each
    gains an "init" key holding a C++ initialization snippet, and some entries have "name"
    or "dType" adjusted. Returns an empty list when the operator has no Serialization
    library attributes.
    """
    serLibOpType = getSerializeOpType(tosaOpName)
    if serLibOpType not in allSerialLibAtts:
        # A list (not a dict), to match the normal return type
        return []
    # Deep copy: names and dTypes are rewritten below and the shared table must not change
    serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
    tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
    serTosaTypeMap = {"ResizeMode": "tosa_mode"}
    serAttsToFix = {
        "reshape": {"new_shape": "shape"},
        "transpose_conv2d": {"output_shape": "out_shape"},
    }
    # Fix attribute names to match with tosa.xml
    for attDefName, tosaSpecName in serAttsToFix.get(tosaOpName, {}).items():
        for opAtts in serLibOpAtts:
            if opAtts["name"] == attDefName:
                opAtts["name"] = tosaSpecName
    for att in serLibOpAtts:
        attName = att["name"]
        attType = att["dType"]
        if attType in serTosaTypeMap:
            # Translate TOSA data types to Serialization library data types for initialization
            init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[attType]}(client_{attName});"
        elif tosaOpName == "avg_pool2d" and attName == "accum_dtype":
            init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);"
            att["dType"] = "tosa::DType"
        elif attName in tosaArgsDict:
            # Initialize Serialization library attributes from their matching function parameter
            init = _serialLibAttInitFromArg(att, tosaArgsDict[attName])
        else:
            # Initialize Serialization library attributes with no matching function parameter
            init = _serialLibAttDefaultInit(att)
        att["init"] = init
    return serLibOpAtts


def _serialLibAttInitFromArg(att, tosaArg):
    """Return the C++ init snippet for attribute att from its matching TOSA argument dict."""
    attName = att["name"]
    attType = att["dType"]
    # Scalar attributes are passed straight through; no initialization code is needed
    if att["SV"] != "V":
        return ""
    if tosaArg["type"] == "tosa_tensor_t":
        # Copy the tensor's raw buffer into a std::vector of the attribute's element type
        return (
            f"std::vector<{attType}> {attName};"
            f"size_t {attName}_size = client_{attName}.size / sizeof({attType});"
            f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);"
            f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);"
        )
    init = f"const std::vector<{attType}> {attName}"
    shape = tosaArg["shape"]
    if shape == "[]":
        # Unknown compile-time size: the length comes from the "<name>_len" parameter
        return init + f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
    # Fixed-size array: shape is e.g. "[4]", giving the one-past-the-end element
    return init + f"(&client_{attName}[0], &client_{attName}{shape});"


def _serialLibAttDefaultInit(att):
    """Return a default C++ init snippet for attribute att; rewrites a "DType" dType to "tosa::DType"."""
    attName = att["name"]
    if att["SV"] == "V":
        return f"std::vector<int32_t> {attName};"
    if att["dType"] == "DType":
        att["dType"] = "tosa::DType"
        return f"const tosa::DType {attName} = tosa::DType::DType_FP32;"
    return f"const {att['dType']} {attName} = 0;"
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 168 | |
| 169 | |
def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
    """
    Replace TOSA argument data types with their matching Serialization attribute data types.
    Delete TOSA arguments where the type couldn't be determined.
    Add Serialization attributes that have no matching TOSA argument.

    tosaArgs: list of argument dicts (as built by getTosaArgs); modified in place.
    serialLibAtts: list of Serialization library attribute dicts (as built by
        getSerialLibAttsForOp). Attributes that are added as new arguments get their
        "init" key set.
    tosaXml: the tosa.xml document, used to look up the TOSA argument type names.
    Returns None.
    """
    tosaArgTypes = getTosaArgTypes(tosaXml)
    serAttsDict = {att["name"]: att for att in serialLibAtts}
    tosaArgsNames = [arg["name"] for arg in tosaArgs]
    # Use a set: a "<name>_len" argument can be queued twice. That happens when its own
    # type is a TOSA argument type with no serialization attribute and it is also the
    # length of another deleted argument. Deleting the same index twice would remove
    # an unrelated argument.
    delTosaArgs = set()
    # Replace TOSA argument data types with their matching Serialization attribute data types.
    for index, tosaArg in enumerate(tosaArgs):
        if tosaArg["type"] not in tosaArgTypes:
            continue
        argName = tosaArg["name"]
        if argName in serAttsDict:
            tosaArg["type"] = serAttsDict[argName]["dType"]
            continue
        # Delete TOSA argument whose data type can't be determined
        delTosaArgs.add(index)
        # Delete corresponding length argument if one exists
        lenArgName = f"{argName}_len"
        if lenArgName in tosaArgsNames:
            delTosaArgs.add(tosaArgsNames.index(lenArgName))
    # Delete from the end so the remaining indices stay valid
    for index in sorted(delTosaArgs, reverse=True):
        del tosaArgs[index]
    # Add Serialization attributes that have no matching TOSA argument
    tosaArgNames = {arg["name"] for arg in tosaArgs}
    for serAtt in serialLibAtts:
        attName = serAtt["name"]
        attType = serAtt["dType"]
        if attName in tosaArgNames or attType == "tosa::DType":
            continue
        if serAtt["SV"] == "V":
            # For vector data types, insert a matching length argument
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{attName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            init = f"const std::vector<{attType}> {attName}(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
            shape = "[]"
        else:
            init = ""
            shape = ""
        serAtt["init"] = init
        # Insert new argument just before the last existing argument (presumably the output)
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": attName,
                "type": attType,
                "shape": shape,
                "category": "",
            },
        )
| 229 | |
| 230 | |
def getOperators(tosaXml):
    """
    Build the operator descriptions used to render the API templates.

    Every <operator> in tosaXml is included except control-flow, const, custom and
    variable operators. Each entry is a dict with keys "name", "serializeAttType",
    "arguments", "serialLibAtts", "inputs" (input argument names that are not
    serialization attributes) and "outputs" (output argument names).
    """
    ignoreOps = frozenset(
        (
            "while_loop",
            "cond_if",
            "const",
            "custom",
            "variable",
            "variable_read",
            "variable_write",
        )
    )
    opsXml = tosaXml.getElementsByTagName("operator")
    allSerialLibAtts = getSerialLibAtts()
    operators = []
    for opXml in opsXml:
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in ignoreOps:
            continue
        serializeAttType = getSerializeOpType(opName)
        tosaArgs = getTosaArgs(opXml)
        serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)
        # Operators without a dedicated attribute but with an "axis" argument use Axis
        hasAxisArg = any(arg["name"] == "axis" for arg in tosaArgs)
        if serializeAttType == "None" and hasAxisArg:
            serializeAttType = "Axis"
            serialLibAtts = [
                {
                    "name": "axis",
                    "dType": "int32_t",
                    "SV": "S",
                    "init": "",
                }
            ]
        updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)
        attNames = {att["name"] for att in serialLibAtts}
        operators.append(
            {
                "name": opName,
                "serializeAttType": serializeAttType,
                "arguments": tosaArgs,
                "serialLibAtts": serialLibAtts,
                "inputs": [
                    arg["name"]
                    for arg in tosaArgs
                    if arg["category"] == "input" and arg["name"] not in attNames
                ],
                "outputs": [
                    arg["name"] for arg in tosaArgs if arg["category"] == "output"
                ],
            }
        )
    return operators
| 281 | |
| 282 | |
def getTosaArgs(opXml, tosaXml=None):
    """
    Return the arguments required for the TOSA operator specified.

    opXml: the <operator> element from tosa.xml.
    tosaXml: the tosa.xml document used to look up TOSA argument type names. Defaults to
        the document that owns opXml. Callers therefore do not depend on a module-level
        ``tosaXml`` global, which only exists when this file is run as a script.
    Returns a list of dicts with "name", "type", "shape" and "category" keys, in argument
    order.
    - Tensor inputs/outputs get type "tosa_tensor_t" and an empty shape.
    - An array whose size is not a numeric literal gets shape "[]" and is preceded by an
      extra "<name>_len" int32_t argument.
    - Any other shape that is not an array is cleared to "".
    """
    if tosaXml is None:
        tosaXml = opXml.ownerDocument
    arguments = []
    tosaTensorTypes = getTosaArgTypes(tosaXml)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    tensorElemTypeMap = {
        "resize_mode_t": "tosa_mode_t",
        "acc_size_t": "tosa_acc_size_t",
    }
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        tensorElemType = xmlArg.getAttribute("tensor-element-type")
        if tensorElemType in tensorElemTypeMap:
            argType = tensorElemTypeMap[tensorElemType]
        else:
            argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # FullyConnected workaround: weight/bias attributes are passed as inputs
        if argName in ("weight", "bias") and argCategory == "attribute":
            argCategory = "input"
        # Update argument type: drop a trailing "*", then map to API types
        if argType.endswith("*"):
            argType = argType[:-1]
        if argCategory in ("input", "output") and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        if argType in tosaTypeMap:
            argType = tosaTypeMap[argType]
        if argShape.startswith("[") and not argShape[1:-1].isnumeric():
            # Add a length argument for arrays with unknown compile-time size
            argShape = "[]"
            arguments.append(
                {
                    "name": f"{argName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                }
            )
        elif not argShape.startswith("["):
            argShape = ""
        # Append argument
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
| 338 | |
| 339 | |
def clangFormat(filename):
    """
    Format filename in place by running ``clang-format -i`` on it.

    clang-format's standard output is discarded.
    Raises subprocess.CalledProcessError if clang-format exits with a non-zero status.
    Raises FileNotFoundError if the clang-format executable cannot be found.
    """
    # subprocess.DEVNULL avoids opening (and having to close) an os.devnull handle
    subprocess.check_call(["clang-format", "-i", filename], stdout=subprocess.DEVNULL)
| 344 | |
| 345 | |
def getSerialLibAtts(attr_def=None):
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are the arguments required by each Serialization library operator.

    attr_def: optional path to the attribute.def file. Defaults to
        thirdparty/serialization_lib/include/attribute.def under getBasePath().
    Each value is a list of dicts with keys "name", "dType" and "SV". "SV" is the
    scalar/vector flag as written in attribute.def; "V" entries are treated as vectors
    by the generator.
    """
    if attr_def is None:
        base_path = getBasePath()
        attr_def = (
            base_path / "thirdparty" / "serialization_lib" / "include" / "attribute.def"
        )
    with open(attr_def) as file:
        return _parseSerialLibAtts(file)


def _parseSerialLibAtts(lines):
    """
    Parse DEF_ATTRIBUTE(Name, count, type, SV, name, ...) entries from attribute.def lines.

    Lines before the first line that starts with "DEF_ATTRIBUTE(" are skipped. An entry
    may span any number of lines (including one) and ends at its first ")". See
    getSerialLibAtts() for the returned structure.
    """
    marker = "DEF_ATTRIBUTE("
    serialLibAtts = {}
    preamble = True
    pending = None  # text of the current entry after "DEF_ATTRIBUTE(", or None
    for line in lines:
        if preamble:
            # Only a definition starting at column 0 ends the header; the header itself
            # may mention DEF_ATTRIBUTE( (presumably in a syntax description).
            if not line.startswith(marker):
                continue
            preamble = False
        line = line.strip()
        if pending is None:
            start = line.find(marker)
            if start < 0:
                continue
            pending = line[start + len(marker) :]
        else:
            pending = f"{pending} {line}"
        end = pending.find(")")
        if end < 0:
            continue
        # fields: operator name, argument count, then (type, SV, name) triples
        fields = [field.strip() for field in pending[:end].split(",")]
        pending = None
        args = []
        for i in range(2, len(fields) - 2, 3):
            args.append(
                {
                    "name": fields[i + 2],
                    "dType": fields[i],
                    "SV": fields[i + 1],
                }
            )
        serialLibAtts[fields[0]] = args
    return serialLibAtts
Grant Watson | 64285a1 | 2022-11-16 15:32:39 +0000 | [diff] [blame] | 387 | |
| 388 | |
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render template with dataTypes and operators and write it to outfile as UTF-8.
    Then report the created file and run clang-format on it.

    environment is accepted for interface compatibility; it is not used here.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as handle:
        handle.write(rendered)
        print(f"Created {outfile}")
    clangFormat(outfile)
| 396 | |
| 397 | |
def generate(environment, dataTypes, operators, base_path):
    """Render the operator API header and source templates into the reference_model tree under base_path."""
    targets = (
        # Generate include/operators.h
        ("operators_h.j2", "reference_model/include/operators.h"),
        # Generate src/operators.cc
        ("operators_cc.j2", "reference_model/src/operators.cc"),
    )
    for templateName, relativePath in targets:
        renderTemplate(
            environment,
            dataTypes,
            operators,
            environment.get_template(templateName),
            base_path / relativePath,
        )
| 408 | |
| 409 | |
if __name__ == "__main__":
    base_path = getBasePath()
    templateDir = Path(__file__).resolve().parent / "templates"
    environment = Environment(loader=FileSystemLoader(templateDir))
    # Bound at module level (not inside a function) so helpers that reference the
    # global name tosaXml keep working.
    specPath = base_path / "thirdparty/specification/tosa.xml"
    tosaXml = minidom.parse(str(specPath))
    dataTypes = getTosaDataTypes(tosaXml)
    operators = getOperators(tosaXml)
    generate(environment, dataTypes, operators, base_path)