| # Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved. |
| # |
| # SPDX-License-Identifier: Apache-2.0 |
| # |
| # Licensed under the Apache License, Version 2.0 (the License); you may |
| # not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # Description: |
| # TOSA mapping functions used by reader. |
| # Contains a mapping from the various TOSA enums and options structs, generated by the FlatBuffer code |
| # generator, to Vela's internal format. |
| import numpy as np |
| |
| from .data_type import DataType |
| from .operation import Op |
| from .operation import TensorIndices |
| from .tosa import ArithmeticRightShiftAttribute # noqa: F401 |
| from .tosa import AxisAttribute # noqa: F401 |
| from .tosa import ClampAttribute # noqa: F401 |
| from .tosa import CondIfAttribute # noqa: F401 |
| from .tosa import ConvAttribute # noqa: F401 |
| from .tosa import ConvQuantInfo # noqa: F401 |
| from .tosa import MatMulQuantInfo # noqa: F401 |
| from .tosa import MulAttribute # noqa: F401 |
| from .tosa import PadQuantInfo # noqa: F401 |
| from .tosa import PoolAttribute # noqa: F401 |
| from .tosa import ReluNAttribute # noqa: F401 |
| from .tosa import RescaleAttribute # noqa: F401 |
| from .tosa import ReshapeAttribute # noqa: F401 |
| from .tosa import ResizeAttribute # noqa: F401 |
| from .tosa import SliceAttribute # noqa: F401 |
| from .tosa import TileAttribute # noqa: F401 |
| from .tosa import TransposeConvAttribute # noqa: F401 |
| from .tosa import UnaryQuantInfo # noqa: F401 |
| from .tosa import WhileLoopAttribute # noqa: F401 |
| from .tosa.DType import DType |
| from .tosa.Op import Op as TosaOp |
| |
| |
# TOSA flatbuffer DType -> Vela's internal DataType.
datatype_map = {
    DType.BOOL: DataType.bool,
    DType.UINT8: DataType.uint8,
    DType.INT4: DataType.int4,
    DType.INT8: DataType.int8,
    DType.INT16: DataType.int16,
    DType.INT32: DataType.int32,
    DType.INT48: DataType.int48,
    DType.FLOAT: DataType.float32,
}

# TOSA flatbuffer DType -> numpy scalar type. Only the DTypes listed here have a numpy mapping.
datatype_map_numpy = {
    # np.bool (an alias of the builtin bool) was deprecated in NumPy 1.20 and removed in 1.24,
    # where referencing it raises AttributeError at import time; np.bool_ is NumPy's boolean type.
    DType.BOOL: np.bool_,
    DType.UINT8: np.uint8,
    DType.INT8: np.int8,
    DType.INT16: np.int16,
    DType.INT32: np.int32,
    DType.FLOAT: np.float32,
}
| |
| |
# TODO duplicate of tflite_mapping
def underscore_to_camel_case(s):
    """Convert a snake_case name to CamelCase, e.g. "out_shape" -> "OutShape".

    Every underscore-separated part goes through str.title(), so empty parts
    (from leading, trailing or doubled underscores) contribute nothing.
    """
    words = s.split("_")
    return "".join(map(str.title, words))
| |
| |
# TODO duplicate of tflite_mapping
def identity(x):
    """Return the argument unchanged (used as the default member converter)."""
    result = x
    return result
| |
| |
class AttrSerializer:
    """Deserializes the members of one generated TOSA attribute table into a dict.

    Member specs given to the constructor are either a snake_case field name
    (read through the generated ``CamelCaseName()`` accessor and stored as-is)
    or a ``(name, is_vector)`` pair. When ``is_vector`` is true the
    ``CamelCaseNameAsNumpy()`` accessor is used and the result is converted to
    a tuple.
    """

    def __init__(self, name, members=None):
        """Look up the generated attribute class and precompute accessor names.

        :param name: name of a generated TOSA attribute module present in this
            module's globals; the module must define a class of the same name.
        :param members: iterable of member specs (see class docstring), or None
            for no members. A single bare string is treated as one member.
        :raises ValueError: if a tuple member spec does not have exactly two items.
        """
        self.name = name
        self.module = globals()[self.name]
        self.cls = getattr(self.module, self.name)
        self.members = []
        if members is None:
            return
        if isinstance(members, str):
            # ("max_int") is a plain string, not a 1-tuple; iterating it would
            # register every character as a separate member name.
            members = (members,)
        for mem in members:
            deserialize = identity
            is_vector = False
            if isinstance(mem, tuple):
                if len(mem) != 2:
                    # Raise rather than assert so the check survives `python -O`.
                    raise ValueError("{0}: invalid member spec {1!r}, expected (name, is_vector)".format(name, mem))
                mem, is_vector = mem
                if is_vector:
                    # Only vector accessors need converting into a tuple.
                    deserialize = tuple
            underscore_mem = mem
            camelcase_mem = underscore_to_camel_case(mem)
            self.members.append((underscore_mem, camelcase_mem, deserialize, is_vector))

    def deserialize(self, op_data):
        """Return a dict of attribute values for the operator *op_data*.

        The dict maps each member's snake_case name to its converted value. It
        is empty when ``op_data.AttributeType()`` is falsy. A member whose
        conversion raises TypeError is left out and a warning is printed.
        """
        attr_type = op_data.AttributeType()
        attr_table = op_data.Attribute()
        attrs = {}
        if not attr_type:
            return attrs
        tosa_attrs = self.cls()
        tosa_attrs.Init(attr_table.Bytes, attr_table.Pos)
        for underscore_mem, camelcase_mem, deserialize, is_vector in self.members:
            accessor = (camelcase_mem + "AsNumpy") if is_vector else camelcase_mem
            value = getattr(tosa_attrs, accessor)()
            try:
                attrs[underscore_mem] = deserialize(value)
            except TypeError:
                # Best effort: skip the member but keep reading the others.
                print("Warning: {0} could not read attribute '{1}'.".format(self.name, underscore_mem))
        return attrs
| |
| |
class QuantSerializer:
    """Deserializes the scalar members of one generated TOSA quantization-info table into a dict."""

    def __init__(self, name, members=None):
        """Look up the generated quant-info class and precompute accessor names.

        :param name: name of a generated TOSA quant-info module present in this
            module's globals; the module must define a class of the same name.
        :param members: iterable of snake_case member names, or None for no
            members. A single bare string is treated as one member.
        """
        self.name = name
        self.module = globals()[self.name]
        self.cls = getattr(self.module, self.name)
        self.members = []
        if members is None:
            return
        if isinstance(members, str):
            # ("input_zp") is a plain string, not a 1-tuple; iterating it would
            # register every character as a separate member name.
            members = (members,)
        for mem in members:
            self.members.append((mem, underscore_to_camel_case(mem), identity))

    def deserialize(self, op_data):
        """Return a dict of quantization values for the operator *op_data*.

        The dict maps each member's snake_case name to its value. It is empty
        when ``op_data.QuantInfoType()`` is falsy. A member whose conversion
        raises TypeError is left out and a warning is printed.
        """
        quant_info_type = op_data.QuantInfoType()
        quant_info = op_data.QuantInfo()
        quant = {}
        if not quant_info_type:
            return quant
        tosa_quant = self.cls()
        tosa_quant.Init(quant_info.Bytes, quant_info.Pos)
        for underscore_mem, camelcase_mem, deserialize in self.members:
            value = getattr(tosa_quant, camelcase_mem)()
            try:
                quant[underscore_mem] = deserialize(value)
            except TypeError:
                # Best effort: skip the member but keep reading the others.
                print("Warning: {0} could not read quant info '{1}'.".format(self.name, underscore_mem))
        return quant
| |
| |
# Attribute and quantization-info readers, one per generated TOSA table type.
# Member specs: a bare name for scalar fields, (name, is_vec) for vector fields.
# A single-member spec needs a trailing comma: ("x") is a string, ("x",) is a tuple.
is_vec = True
pool_attrs = AttrSerializer("PoolAttribute", (("padding", is_vec), ("kernel", is_vec), ("stride", is_vec)))
conv_attrs = AttrSerializer("ConvAttribute", (("padding", is_vec), ("stride", is_vec), ("dilation", is_vec)))
transpose_conv_attrs = AttrSerializer(
    "TransposeConvAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec))
)
relun_attrs = AttrSerializer("ReluNAttribute", ("max_int",))
axis_attrs = AttrSerializer("AxisAttribute", ("axis",))
reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),))
slice_attrs = AttrSerializer("SliceAttribute", (("begin", is_vec), ("size", is_vec)))
tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),))
resize_attrs = AttrSerializer(
    "ResizeAttribute", (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), "shift")
)
clamp_attrs = AttrSerializer("ClampAttribute", ("min_int", "max_int"))
rescale_attrs = AttrSerializer(
    "RescaleAttribute",
    ("input_zp", "output_zp", ("multiplier", is_vec), ("shift", is_vec), "scale32", "double_round", "per_channel"),
)
mul_attrs = AttrSerializer("MulAttribute", ("shift",))
ars_attrs = AttrSerializer("ArithmeticRightShiftAttribute", ("round",))
condif_attrs = AttrSerializer("CondIfAttribute", ("then_branch", "else_branch"))  # TODO these are references
while_attrs = AttrSerializer("WhileLoopAttribute", ("cond_branch", "body_branch"))  # TODO these are references

unary_quant_info = QuantSerializer("UnaryQuantInfo", ("input_zp", "output_zp"))
conv_quant_info = QuantSerializer("ConvQuantInfo", ("input_zp", "weight_zp"))
matmul_quant_info = QuantSerializer("MatMulQuantInfo", ("a_zp", "b_zp"))
pad_quant_info = QuantSerializer("PadQuantInfo", ("input_zp",))
| |
# TOSA operators the reader does not handle yet; grouped following the TOSA
# specification's operator categories.
unsupported_tosa_operators = {
    TosaOp.UNKNOWN,
    # Tensor operators
    TosaOp.ARGMAX, TosaOp.CONV3D, TosaOp.MATMUL, TosaOp.TRANSPOSE_CONV2D,
    # Activation functions
    TosaOp.SIGMOID, TosaOp.TANH,
    # Elementwise binary operators
    TosaOp.BITWISE_AND, TosaOp.BITWISE_OR, TosaOp.BITWISE_XOR, TosaOp.INTDIV,
    TosaOp.LOGICAL_AND, TosaOp.LOGICAL_LEFT_SHIFT, TosaOp.LOGICAL_RIGHT_SHIFT,
    TosaOp.LOGICAL_OR, TosaOp.LOGICAL_XOR, TosaOp.MAXIMUM, TosaOp.MINIMUM, TosaOp.POW,
    # Elementwise unary operators
    TosaOp.ABS, TosaOp.BITWISE_NOT, TosaOp.CEIL, TosaOp.CLZ, TosaOp.EXP, TosaOp.FLOOR,
    TosaOp.LOG, TosaOp.LOGICAL_NOT, TosaOp.NEGATE, TosaOp.RECIPROCAL, TosaOp.RSQRT,
    # Elementwise ternary operators
    TosaOp.SELECT,
    # Comparison operators
    TosaOp.EQUAL, TosaOp.GREATER, TosaOp.GREATER_EQUAL,
    # Reduction operators
    TosaOp.REDUCE_ANY, TosaOp.REDUCE_ALL, TosaOp.REDUCE_MAX,
    TosaOp.REDUCE_MIN, TosaOp.REDUCE_PRODUCT, TosaOp.REDUCE_SUM,
    # Data layout
    TosaOp.REVERSE, TosaOp.TILE,
    # Scatter/gather
    TosaOp.GATHER, TosaOp.SCATTER,
    # Image
    TosaOp.RESIZE,
    # Type conversion
    TosaOp.CAST,
    # Custom
    TosaOp.CUSTOM,
    # Control flow
    TosaOp.COND_IF, TosaOp.WHILE_LOOP,
}
| |
| |
# Operand index layouts referenced by tosa_operator_map. Each TensorIndices takes
# three index lists, matching the IFM / WEIGHTS / BIAS naming of the constants.
TOSA_NO_INDICES = TensorIndices([], [], [])
TOSA_IFM_INDICES = TensorIndices([0], [], [])
TOSA_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
TOSA_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])

# Layouts not used by any TOSA operator yet, kept for reference:
# TOSA_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
# TOSA_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
# TOSA_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
# TOSA_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
# TOSA_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
| |
| |
# TOSA operator -> (Vela Op, attribute reader, quant-info reader, tensor indices).
# Operators marked TODO here are the ones listed in unsupported_tosa_operators.
tosa_operator_map = {
    # TODO TosaOp.ARGMAX: (Op.ArgMax, axis_attrs, None)
    TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, unary_quant_info, TOSA_IFM_INDICES),
    TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
    # TODO TosaOp.CONV3D
    TosaOp.DEPTHWISE_CONV2D: (Op.DepthwiseConv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
    TosaOp.FULLY_CONNECTED: (Op.FullyConnected, None, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
    # TODO TosaOp.MATMUL
    TosaOp.MAX_POOL2D: (Op.MaxPool, pool_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, conv_quant_info)
    TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES),
    TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.SIGMOID, TosaOp.TANH
    TosaOp.ADD: (Op.Add, None, None, TOSA_IFM_IFM2_INDICES),
    TosaOp.ARITHMETIC_RIGHT_SHIFT: (Op.SHR, ars_attrs, None, TOSA_IFM_IFM2_INDICES),
    # TODO TosaOp.BITWISE_AND, BITWISE_OR, BITWISE_XOR, INTDIV, LOGICAL_AND, LOGICAL_LEFT_SHIFT,
    #      LOGICAL_RIGHT_SHIFT, LOGICAL_OR, LOGICAL_XOR, MAXIMUM, MINIMUM
    TosaOp.MUL: (Op.Mul, mul_attrs, None, TOSA_IFM_IFM2_INDICES),
    # TODO TosaOp.POW
    TosaOp.SUB: (Op.Sub, None, None, TOSA_IFM_IFM2_INDICES),
    # TODO Is the table content in input[1] always constant?
    TosaOp.TABLE: (Op.Table, None, None, TOSA_IFM_INDICES),
    # TODO TosaOp.ABS, BITWISE_NOT, CEIL, CLZ, EXP, FLOOR, LOG, LOGICAL_NOT, NEGATE, RECIPROCAL, RSQRT,
    #      SELECT, EQUAL, GREATER, GREATER_EQUAL,
    #      REDUCE_ANY, REDUCE_ALL, REDUCE_MAX, REDUCE_MIN, REDUCE_PRODUCT, REDUCE_SUM
    TosaOp.CONCAT: (Op.Concat, axis_attrs, None, TOSA_CONCAT_INDICES),
    # TODO Is the padding intended to be a dynamic input? The TOSA spec states it as an attribute.
    #      Handled as for TFLite for now.
    TosaOp.PAD: (Op.Pad, None, pad_quant_info, TOSA_IFM_INDICES),
    TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.REVERSE
    TosaOp.SLICE: (Op.SplitSliceRead, slice_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.TILE
    # TODO Are the perms intended to be a dynamic input? The TOSA spec states them as an attribute.
    TosaOp.TRANSPOSE: (Op.Transpose, None, None, TOSA_IFM_IFM2_INDICES),
    # TODO TosaOp.GATHER, SCATTER, RESIZE, CAST
    TosaOp.RESCALE: (Op.Rescale, rescale_attrs, None, TOSA_IFM_INDICES),
    TosaOp.CONST: (Op.Const, None, None, TOSA_NO_INDICES),
    TosaOp.IDENTITY: (Op.Identity, None, None, TOSA_IFM_INDICES),
    # TODO TosaOp.CUSTOM, COND_IF, WHILE_LOOP
}
| |
# Reverse lookup: Vela Op -> (TosaOp value, attribute reader).
tosa_operator_inv_map = {
    vela_op: (tosa_op, attr_reader) for tosa_op, (vela_op, attr_reader, *_rest) in tosa_operator_map.items()
}
| |
| |
def tosa_type_name(builtin):
    """Return the name of the TosaOp member whose value equals *builtin*.

    :param builtin: a TosaOp enum value (e.g. ``TosaOp.CONV2D``).
    :returns: the member name as a string (e.g. ``"CONV2D"``).
    :raises ValueError: if no TosaOp member has that value.
    """
    for name, value in vars(TosaOp).items():
        # vars() of a class also contains __module__, __doc__, ...; never report those.
        if name.startswith("__"):
            continue
        if value == builtin:
            return name
    # Raise explicitly instead of letting next() leak StopIteration, which is
    # turned into RuntimeError when this is called from inside a generator.
    raise ValueError("Unknown TOSA operator value: {0!r}".format(builtin))
| |
| |
def optype_to_tosa_op_type(op_type):
    """Return the TOSA operator name (e.g. "CONV2D") that maps to the Vela *op_type*.

    Vela ops with no entry in tosa_operator_map (including those not yet
    supported) yield "UNKNOWN". The result is always a string.
    """
    if op_type in tosa_operator_inv_map:
        return tosa_type_name(tosa_operator_inv_map[op_type][0])
    # Return the enum member's name rather than its integer value (TosaOp.UNKNOWN)
    # so both branches have the same return type.
    return "UNKNOWN"