blob: 4b0a2fece6519861a47873acccdd88d1a1746128 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright (c) 2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import re
import xml.etree.ElementTree as ET
def get_rank_from_shape(shape):
    """Derive a rank from a textual shape description.

    Example shape strings: ``shape1``, ``[2]``, ``[N,H,W,C]``.

    Args:
        shape: Shape description string.

    Returns:
        A ``(checkable, rank)`` tuple. ``checkable`` is False (and ``rank``
        is -1) when the string contains no ``[`` or contains ``[]``.
        Otherwise ``rank`` is the number of comma separated entries found
        between the brackets.

    Raises:
        RuntimeError: If the string is checkable but does not start with a
            bracketed list.
    """
    has_rank_info = "[" in shape and "[]" not in shape
    if not has_rank_info:
        return (False, -1)

    # A single bracketed integer such as "[2]" describes a rank 1 shape.
    if re.match(r"\[(\d+)\]", shape) is not None:
        return (True, 1)

    # Otherwise count the comma separated descriptors inside the brackets.
    bracketed = re.match(r"\[(.*)\]", shape)
    if bracketed is None:
        raise RuntimeError(f"Unable to parse shape {shape}")
    descriptors = bracketed.group(1).split(",")
    return (True, len(descriptors))
class TOSAOperatorArgumentCategory:
    """Category of an operator argument together with its profile list.

    Attributes:
        name: Category name.
        profiles: Profiles attached to the category; None when not supplied.
    """

    def __init__(self, name, profiles=None):
        self.name, self.profiles = name, profiles
class TOSAProfile:
    """Record describing one profile of the specification.

    Attributes:
        profile: Value of the profile identifier.
        name: Profile name.
        description: Profile description.
        status: Profile status.
        ops: List that starts out empty for every instance.
    """

    def __init__(self, profile, name, description, status):
        self.profile, self.name = profile, name
        self.description, self.status = description, status
        # A fresh list per instance; this constructor never fills it.
        self.ops = []
class TOSAProfileExtension:
    """Record describing one profile extension of the specification.

    Attributes:
        name: Extension name.
        description: Extension description.
        status: Extension status.
        profiles: Profiles associated with the extension.
        ops: List that starts out empty for every instance.
    """

    def __init__(self, name, description, status, profiles):
        self.name, self.description = name, description
        self.status, self.profiles = status, profiles
        # A fresh list per instance; this constructor never fills it.
        self.ops = []
class TOSAEnum:
    """Named enumeration with a description and its values.

    Attributes:
        name: Enumeration name.
        description: Enumeration description.
        values: Values belonging to the enumeration, stored as given.
    """

    def __init__(self, name, description, values):
        self.name, self.description, self.values = name, description, values
class TOSALevel:
    """Named level with a description and its maximum values.

    Attributes:
        name: Level name.
        desc: Level description.
        maximums: Maximum values for the level, stored as given.
    """

    def __init__(self, name, desc, maximums):
        self.name, self.desc, self.maximums = name, desc, maximums
class TOSAOperatorArgument:
    """Description of a single operator argument.

    Attributes:
        name: Argument name.
        description: Argument description.
        categories: Categories the argument belongs to.
        type: Argument type (constructor parameter ``ty``).
        tensor_element_type: Tensor element type (constructor parameter
            ``elty``).
        shape: Shape description.
        levellimits: Level limits attached to the argument.
        rank: Rank information for the argument.
        optional: True when the argument is optional.
    """

    def __init__(
        self,
        name,
        description,
        categories,
        ty,
        elty,
        shape,
        levellimits,
        rank,
        optional=False,
    ):
        """Store the argument description.

        Raises:
            TypeError: If ``optional`` is not a bool.
        """
        # Validate with an explicit raise rather than ``assert`` so the check
        # is still performed when Python runs with optimisations (-O).
        if not isinstance(optional, bool):
            raise TypeError(
                f"optional must be a bool, got {type(optional).__name__}"
            )
        self.name = name
        self.description = description
        self.categories = categories
        self.type = ty
        self.tensor_element_type = elty
        self.shape = shape
        self.levellimits = levellimits
        self.rank = rank
        self.optional = optional
class TOSAOperatorDataTypeSupport:
    """One supported data type combination of an operator.

    Attributes:
        mode: Name of the type support mode.
        tymap: Mapping from the operator's type names to concrete types.
        profiles: Profiles for which this combination applies.
        version_added: Version in which the combination was added.
    """

    def __init__(self, mode, tymap, version_added, profiles):
        self.mode, self.tymap = mode, tymap
        self.profiles, self.version_added = profiles, version_added
class TOSAOperator:
    """A single operator with its arguments and type information.

    Attributes:
        name: Operator name.
        arguments: Arguments of the operator.
        types: Type names used by the operator.
        typesupports: Supported data type combinations.
    """

    def __init__(self, name, arguments, types, typesupports):
        self.name, self.arguments = name, arguments
        self.types, self.typesupports = types, typesupports
class TOSAOperatorGroup:
    """Named collection of operators.

    Attributes:
        name: Group name.
        operators: Operators belonging to the group.
    """

    def __init__(self, name, operators):
        self.name, self.operators = name, operators
class TOSASpec:
    """In-memory view of a TOSA specification XML document.

    The whole document is loaded eagerly by the constructor. Afterwards the
    instance exposes:

    * ``xmlroot`` - root element of the parsed document
    * ``version_major``, ``version_minor``, ``version_patch`` - ints
    * ``version_is_draft`` - bool
    * ``profiles`` - list of TOSAProfile
    * ``profile_extensions`` - list of TOSAProfileExtension
    * ``levels`` - list of TOSALevel
    * ``operatorgroups`` - list of TOSAOperatorGroup
    * ``enums`` - list of TOSAEnum
    """

    def __init__(self, xmlpath):
        """Parse ``xmlpath`` and load every section of the specification.

        Args:
            xmlpath: Source passed directly to
                ``xml.etree.ElementTree.parse``.

        Raises:
            RuntimeError: If the document has no ``<version>`` element or an
                operator argument fails the shape/rank consistency checks.
        """
        tree = ET.parse(xmlpath)
        self.xmlroot = tree.getroot()
        self.profiles = []
        self.profile_extensions = []
        self.levels = []
        self.operatorgroups = []
        self.enums = []
        self.__load_spec()

    def __load_spec(self):
        """Populate the version fields and every list attribute."""
        root = self.xmlroot
        self.__load_version()
        self.profiles.extend(
            self.__load_profile(profile)
            for profile in root.findall("./profiles/profile")
        )
        self.profile_extensions.extend(
            self.__load_profile_extension(profile_ext)
            for profile_ext in root.findall("./profile_extensions/profile_extension")
        )
        self.levels.extend(
            self.__load_level(level) for level in root.findall("./levels/level")
        )
        self.operatorgroups.extend(
            self.__load_operator_group(group)
            for group in root.findall("./operators/operatorgroup")
        )
        self.enums.extend(self.__load_enum(enum) for enum in root.findall("./enum"))

    def __load_version(self):
        """Read the ``<version>`` element into the ``version_*`` attributes.

        Raises:
            RuntimeError: If the document has no ``<version>`` element.
        """
        version = self.xmlroot.find("./version")
        if version is None:
            # Fail with a spec-level error instead of an AttributeError on None.
            raise RuntimeError("Specification XML has no <version> element")
        self.version_major = int(version.get("major"))
        self.version_minor = int(version.get("minor"))
        self.version_patch = int(version.get("patch"))
        # Only the literal string "true" marks a draft.
        self.version_is_draft = version.get("draft") == "true"

    def __load_profile(self, xml_profile):
        """Build a TOSAProfile from the attributes of a ``<profile>`` element."""
        profile = xml_profile.get("profile")
        name = xml_profile.get("name")
        description = xml_profile.get("description")
        status = xml_profile.get("status")
        return TOSAProfile(profile, name, description, status)

    def __load_profile_extension(self, ext):
        """Build a TOSAProfileExtension from a ``<profile_extension>`` element.

        The text of each child element becomes one entry of ``profiles``.
        """
        name = ext.get("name")
        description = ext.get("description")
        status = ext.get("status")
        profiles = [x.text for x in ext]
        return TOSAProfileExtension(name, description, status, profiles)

    def __load_level(self, level):
        """Build a TOSALevel from a ``<level>`` element.

        The element text (stripped) is the description; the ``max_*``
        attributes are stored under upper-case keys, unconverted.
        """
        name = level.get("name")
        desc = level.text.strip()
        maximums = {
            "MAX_RANK": level.get("max_rank"),
            "MAX_KERNEL": level.get("max_kernel"),
            "MAX_STRIDE": level.get("max_stride"),
            "MAX_SCALE": level.get("max_scale"),
            "MAX_LOG2_SIZE": level.get("max_log2_size"),
            "MAX_NESTING": level.get("max_nesting"),
            "MAX_TENSOR_LIST_SIZE": level.get("max_tensor_list_size"),
        }
        return TOSALevel(name, desc, maximums)

    def __load_operator_group(self, group):
        """Build a TOSAOperatorGroup from an ``<operatorgroup>`` element."""
        name = group.get("name")
        operators = [self.__load_operator(op) for op in group.findall("operator")]
        return TOSAOperatorGroup(name, operators)

    def __load_operator(self, op):
        """Build a TOSAOperator from an ``<operator>`` element.

        Collects the operator's arguments, its type names, and one
        TOSAOperatorDataTypeSupport per ``<typesupport>`` child.
        """
        name = op.find("name").text
        args = [
            self.__load_operator_argument(arg, name)
            for arg in op.findall("arguments/argument")
        ]
        # TODO add pseudo-code to operator object?
        types = [ty.get("name") for ty in op.findall("types/type")]

        typesupports = []
        for tysup in op.findall("typesupport"):
            tsmode = tysup.get("mode")
            version_added = tysup.get("version_added")
            tsprofiles = []
            for p in tysup.findall("op_profile"):
                tsp_name = p.get("name")
                and_name = p.get("and_name")
                if and_name is not None:
                    # Join the two names in ascending string order so the
                    # combined label does not depend on attribute order.
                    if and_name < tsp_name:
                        tsp_name = f"{and_name} and {tsp_name}"
                    else:
                        tsp_name = f"{tsp_name} and {and_name}"
                tsprofiles.append(tsp_name)
            # Each operator type name is looked up as an attribute of the
            # <typesupport> element.
            tsmap = {ty: tysup.get(ty) for ty in types}
            typesupports.append(
                TOSAOperatorDataTypeSupport(tsmode, tsmap, version_added, tsprofiles)
            )
        return TOSAOperator(name, args, types, typesupports)

    def __load_operator_argument(self, arg, op_name):
        """Build a TOSAOperatorArgument from an ``<argument>`` element.

        Args:
            arg: The ``<argument>`` element.
            op_name: Name of the owning operator, used in error messages.

        Raises:
            RuntimeError: If a ``<rank>`` child is present but the ``shape``
                attribute is missing, if the shape is ``-`` while the rank
                is not 0..0, if the rank derived from the shape lies outside
                the declared min/max, or if there is no ``<rank>`` child
                and the shape is not ``-``.
        """
        name = arg.get("name")
        desc = arg.find("description").text.strip()
        argtype = arg.get("type")
        argtelty = arg.get("tensor-element-type")
        shape = arg.get("shape")
        rank = []
        optional = arg.get("optional", "false") == "true"

        r = arg.find("rank")
        if r is not None:
            rank = [r.get("min"), r.get("max")]
            if shape is None:
                # Without this guard the shape parser fails with a TypeError
                # that names neither the operator nor the argument.
                raise RuntimeError(
                    f"Shape not present for {op_name}: {name} when rank is given"
                )
            if shape == "-" and (rank[0] != "0" or rank[1] != "0"):
                raise RuntimeError(
                    "rank is not empty or non-zero, but shape is '-'"
                    f" for {op_name}: {name}"
                )
            # validate rank against the shape argument
            (shape_check, shape_rank) = get_rank_from_shape(shape)
            if shape_check and (shape_rank < int(rank[0]) or shape_rank > int(rank[1])):
                raise RuntimeError(
                    "Description of shape rank doesn't match XML rank"
                    f" min/max: {op_name} {name} shape: {shape} shape_rank: "
                    f"{shape_rank} min/max: {rank[0]} {rank[1]}"
                )
        else:
            if shape != "-":
                raise RuntimeError(
                    f"Rank not present for {op_name}: {name} when shape is {shape}"
                )

        levellimits = [
            [levellimit.get("value"), levellimit.get("limit")]
            for levellimit in arg.findall("levellimit")
        ]

        # Each match is (category, profile-list); the profile list is the
        # empty string when absent, which splits to [""].
        cats = re.findall(
            r"(input|output|attribute)\(?([A-Z,]+)?\)?", arg.get("category")
        )
        argcats = [
            TOSAOperatorArgumentCategory(cat[0], cat[1].split(",")) for cat in cats
        ]

        return TOSAOperatorArgument(
            name, desc, argcats, argtype, argtelty, shape, levellimits, rank, optional
        )

    def __load_enum(self, arg):
        """Build a TOSAEnum from an ``<enum>`` element.

        Each ``<enumval>`` child contributes a ``(name, value, description)``
        tuple of its attribute values.
        """
        name = arg.get("name")
        desc = arg.get("description").strip()
        values = [
            (val.get("name"), val.get("value"), val.get("description"))
            for val in arg.findall("enumval")
        ]
        return TOSAEnum(name, desc, values)

    def get_enum_by_name(self, name):
        """Return the first loaded enum whose name equals ``name``.

        Returns:
            The matching TOSAEnum, or None when no enum has that name.
        """
        for e in self.enums:
            if e.name == name:
                return e
        return None