MLBEDSW-3148: Refactor Operation

- op.type is now an enum instead of a string (see the usage sketch below)
- Removed unused operator codes
- Refactored attributes such as npu_block_type and fused_activation_function:
  they are no longer kept in op.attrs but are derived from the operator code
  or stored in dedicated fields
- Refactored operator index calculation: tensor purposes (IFM, IFM2, weights,
  bias) are now resolved via per-operator TensorIndices
- Refactored a number of operator sets
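
A rough sketch of how the refactored interface is used (names as introduced
by this patch; tensor wiring via add_input_tensor() etc. is omitted):

    from ethosu.vela.operation import NpuBlockType, Op, Operation

    op = Operation(Op.Conv2DBias, "conv1")

    # Block type, tensor indices and bias requirements come from the operator
    # code itself instead of op.attrs["npu_block_type"] and friends
    assert op.type.npu_block_type == NpuBlockType.ConvolutionMxN
    assert op.type.needs_bias()

    # Tensor purposes are resolved through per-operator TensorIndices; the
    # accessors return None until the corresponding inputs are attached
    ifm, weights, bias, ofm = op.get_ifm_weights_biases_ofm()

    # Operator sets are built from a predicate instead of string tuples
    relu_ops = Op.op_set(Op.is_relu_op)  # {Op.Relu, Op.Relu6, Op.ReluN1To1}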

Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 1481887..a2b67df 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 # Description:
 # Internal representation of a Neural Network Operation.
-import enum
+from collections import namedtuple
+from enum import Enum
 
 
-class NpuBlockType(enum.Enum):
+class NpuBlockType(Enum):
     Default = 0
     ConvolutionMxN = 1
     VectorProduct = 2
@@ -28,10 +29,266 @@
     ReduceSum = 6
 
 
+# Classifies operators of type Custom
+class CustomType(Enum):
+    ThirdPartyOp = 0  # Third party custom op
+    NpuOp = 1  # NPU op
+    ExistingNpuOp = 2  # NPU op that was part of the input network
+
+
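+# Positions within Operation.inputs of the tensors serving each purpose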
+TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])
+
+NO_INDICES = TensorIndices([], [], [])
+IFM_INDICES = TensorIndices([0], [], [])
+IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
+IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
+IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
+CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
+TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
+CONCAT_INDICES = TensorIndices([1, 2], [], [])
+SPLIT_IFM_INDICES = TensorIndices([1], [], [])
+BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
+
+
+# Static information related to operation codes
+class OperatorInfo:
+    __slots__ = ("id", "block_type", "indices", "is_unary")
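+    # Class-level counter; gives each OperatorInfo a unique, increasing id (used to order Op members)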
+    _id = 0
+
+    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
+        OperatorInfo._id += 1
+        self.id = OperatorInfo._id
+        self.block_type = block_type
+        self.indices = indices  # Indices of the different tensor purposes
+        self.is_unary = is_unary  # True for elementwise operators that take a single IFM
+
+
+# Internally used operation codes
+class Op(Enum):
+    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
+    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    AddN = OperatorInfo()
+    Any = OperatorInfo()
+    ArgMax = OperatorInfo()
+    ArgMin = OperatorInfo()
+    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    BatchMatMul = OperatorInfo()
+    BatchToSpaceND = OperatorInfo()
+    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)
+
+    CLZ = OperatorInfo(
+        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
+    )  # NPU specific operation
+    Call = OperatorInfo()
+    Cast = OperatorInfo()
+    Ceil = OperatorInfo()
+    Concat = OperatorInfo(indices=CONCAT_INDICES)
+    ConcatEmbeddings = OperatorInfo()
+    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
+    ConcatTFLite = OperatorInfo()
+    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
+    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
+    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
+    Conv2DBackpropInputSwitchedBias = OperatorInfo(
+        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
+    )
+    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
+    Cos = OperatorInfo()
+    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
+    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
+    DMA = OperatorInfo()
+    Delegate = OperatorInfo()
+    Densify = OperatorInfo()
+    DepthToSpace = OperatorInfo()
+    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
+    Dequantize = OperatorInfo()
+    Div = OperatorInfo()
+    Elu = OperatorInfo()
+    EmbeddingLookup = OperatorInfo()
+    EmbeddingLookupSparse = OperatorInfo()
+    Equal = OperatorInfo()
+    Exp = OperatorInfo()
+    ExpandDims = OperatorInfo(indices=IFM_INDICES)
+    FakeQuantWithMinMaxArgs = OperatorInfo()
+    Fill = OperatorInfo()
+    Floor = OperatorInfo()
+    FloorDiv = OperatorInfo()
+    FloorMod = OperatorInfo()
+    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
+    GatherNd = OperatorInfo()
+    GatherV2 = OperatorInfo()
+    Greater = OperatorInfo()
+    GreaterEqual = OperatorInfo()
+    HardSwish = OperatorInfo()
+    HashtableLookup = OperatorInfo()
+    Identity = OperatorInfo()
+    If = OperatorInfo()
+    L2Norm = OperatorInfo()
+    L2Pool2D = OperatorInfo()
+    LRN = OperatorInfo()
+    LSHProjection = OperatorInfo()
+    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
+    Less = OperatorInfo()
+    LessEqual = OperatorInfo()
+    Log = OperatorInfo()
+    LogSoftmax = OperatorInfo()
+    LogicalAnd = OperatorInfo()
+    LogicalNot = OperatorInfo()
+    LogicalOr = OperatorInfo()
+    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    LUT = OperatorInfo()  # NPU specific; the operator uses a LUT, only used as a fused activation function
+    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    MatrixDiag = OperatorInfo()
+    MatrixSetDiag = OperatorInfo()
+    Max = OperatorInfo()
+    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    Mean = OperatorInfo()
+    Min = OperatorInfo()
+    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    MirrorPad = OperatorInfo()
+    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    Neg = OperatorInfo()
+    NonMaxSuppressionV4 = OperatorInfo()
+    NonMaxSuppressionV5 = OperatorInfo()
+    NotEqual = OperatorInfo()
+    OneHot = OperatorInfo()
+    Pack = OperatorInfo()
+    PackReshaped = OperatorInfo(indices=IFM_INDICES)
+    Pad = OperatorInfo()
+    PadV2 = OperatorInfo()
+    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
+    Pow = OperatorInfo()
+    Prelu = OperatorInfo()
+    Prod = OperatorInfo()
+    Quantize = OperatorInfo()
+    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
+    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
+    Range = OperatorInfo()
+    Rank = OperatorInfo()
+    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
+    Relu = OperatorInfo(indices=IFM_INDICES)
+    Relu6 = OperatorInfo(indices=IFM_INDICES)
+    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
+    Reshape = OperatorInfo(indices=IFM_INDICES)
+    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    ResizeNearestNeighbor = OperatorInfo()
+    ReverseSequence = OperatorInfo()
+    ReverseV2 = OperatorInfo()
+    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    Round = OperatorInfo()
+    Rsqrt = OperatorInfo()
+    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
+    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
+    ScatterNd = OperatorInfo()
+    SegmentSum = OperatorInfo()
+    Select = OperatorInfo()
+    SelectV2 = OperatorInfo()
+    Shape = OperatorInfo()
+    Sigmoid = OperatorInfo(indices=IFM_INDICES)
+    SignBit = OperatorInfo()
+    Sin = OperatorInfo()
+    SkipGram = OperatorInfo()
+    Slice = OperatorInfo(indices=IFM_INDICES)
+    Softmax = OperatorInfo()
+    SpaceToBatchND = OperatorInfo()
+    SpaceToDepth = OperatorInfo()
+    SparseToDense = OperatorInfo()
+    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
+    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
+    SplitV = OperatorInfo(indices=IFM_INDICES)
+    Sqrt = OperatorInfo()
+    Square = OperatorInfo()
+    SquaredDifference = OperatorInfo()
+    Squeeze = OperatorInfo(indices=IFM_INDICES)
+    StridedSlice = OperatorInfo(indices=IFM_INDICES)
+    StridedSliceOptions = OperatorInfo()
+    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
+    Sum = OperatorInfo()
+    Svdf = OperatorInfo()
+    Tanh = OperatorInfo(indices=IFM_INDICES)
+    Tile = OperatorInfo()
+    TopKV2 = OperatorInfo()
+    Transpose = OperatorInfo()
+    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    Unique = OperatorInfo()
+    Unpack = OperatorInfo()
+    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
+    Where = OperatorInfo()
+    While = OperatorInfo()
+    ZerosLike = OperatorInfo()
+
+    @property
+    def info(self):
+        return self.value
+
+    @property
+    def npu_block_type(self):
+        return self.info.block_type
+
+    def is_conv2d_op(self):
+        return self.info.block_type == NpuBlockType.ConvolutionMxN
+
+    def is_depthwise_conv2d_op(self):
+        return self.info.block_type == NpuBlockType.ConvolutionDepthWise
+
+    def is_pool_op(self):
+        return self.info.block_type == NpuBlockType.Pooling
+
+    def is_maxpool_op(self):
+        return self in (Op.MaxPool, Op.QuantizedMaxPool)
+
+    def is_avgpool_op(self):
+        return self in (Op.QuantizedAvgPool, Op.AvgPool)
+
+    def is_elementwise_op(self):
+        return self.info.block_type == NpuBlockType.ElementWise
+
+    def is_unary_elementwise_op(self):
+        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary
+
+    def is_binary_elementwise_op(self):
+        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary
+
+    def is_relu_op(self):
+        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1)
+
+    def is_activation_op(self):
+        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
+
+    def is_split_op(self):
+        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)
+
+    def is_concat_op(self):
+        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)
+
+    def needs_bias(self):
+        return bool(self.info.indices.biases)
+
+    @classmethod
+    def op_set(cls, predicate):
+        # Returns the set of all operator codes that fulfill the given predicate
+        return {op_type for op_type in Op if predicate(op_type)}
+
+    def __str__(self):
+        return self.name
+
+    __repr__ = __str__
+
+    def __lt__(self, other):
+        return self.value.id < other.value.id
+
+
 def create_avgpool_nop(name):
-    op = Operation("AvgPool", name)
+    op = Operation(Op.AvgPool, name)
     op.attrs["padding"] = b"VALID"
-    op.attrs["npu_block_type"] = NpuBlockType.Pooling
     op.attrs["stride_w"] = 1
     op.attrs["stride_h"] = 1
     op.attrs["filter_width"] = 1
@@ -70,6 +327,9 @@
         "flops",
         "scheduled_pass",
         "run_on_npu",
+        "activation",
+        "memory_function",
+        "forced_output_quantization",
         "activation_lut",
     )
 
@@ -81,6 +341,13 @@
         self.outputs = []
         self.flops = 0
         self.run_on_npu = True
+        # Fused activation function; if not None, an operator code (e.g. Op.LUT)
+        self.activation = None
+        # Fused memory function; if not None, an operator code
+        self.memory_function = None
+        # If not None: contains QuantizationParameters to be used as output quantization
+        # (overriding the OFM tensor's quantization), used for LUT
+        self.forced_output_quantization = None
         self.scheduled_pass = None
         self.op_index = None  # input network operator index
         self.activation_lut = None
@@ -92,173 +359,95 @@
         res.inputs = list(self.inputs)
         res.outputs = list(self.outputs)
         res.flops = self.flops
+        res.run_on_npu = self.run_on_npu
+        res.activation = self.activation
+        res.memory_function = self.memory_function
+        res.forced_output_quantization = self.forced_output_quantization
         res.scheduled_pass = self.scheduled_pass
         res.op_index = None  # not relevant as not part of input network
 
         return res
 
     def __str__(self):
-        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)
+        return "<nng.Operation '{}' type={}>".format(self.name, self.type)
 
     __repr__ = __str__
 
-    def get_ifm_ifm2_weight_bias_ofm_indices(self):
-        ifm_idx = -1
-        ifm2_idx = -1
-        weight_idx = -1
-        bias_idx = -1
-        ofm_idx = -1
-        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
-        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise):
-            ifm_idx = 0
-            weight_idx = 1
-            ofm_idx = 0
-
-            if self.type in ("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct"):
-                if len(self.inputs) >= 3:
-                    bias_idx = 2
-
-            elif self.type == "Conv2DBackpropInputSwitchedBias":
-                bias_idx = 3
-
-        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
-            ifm_idx = 0
-            ofm_idx = 0
-        elif npu_block_type == NpuBlockType.VectorProduct:
-            ifm_idx = 0
-            weight_idx = 1
-            ofm_idx = 0
-
-            if self.type == "FullyConnectedAct":
-                if len(self.inputs) >= 3:
-                    bias_idx = 2
-
-            if self.type == "BlockLSTM":
-                ifm_idx = 3
-                weight_idx = 4
-                ofm_idx = 6
-
-        elif npu_block_type == NpuBlockType.ElementWise:
-            ifm_idx = 0
-            ifm2_idx = 1
-            ofm_idx = 0
-
-            # LeakyRelu, Abs and CLZ have a single IFM
-            if self.type in ("LeakyRelu", "Abs", "CLZ"):
-                ifm2_idx = -1
-
-        elif self.type == "Conv2DBackpropInput":
-            ifm_idx = 2
-            weight_idx = 1
-            ofm_idx = 0
-
-        elif self.type in ("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims", "Sigmoid", "Tanh"):
-            ifm_idx = 0
-            ofm_idx = 0
-
-        elif self.is_split_op():
-            ifm_idx = 0
-            ofm_idx = 0
-            if self.type == "Split":
-                ifm_idx = 1
-
-        elif self.is_concat_op():
-            ifms, _ = self.get_concat_inputs_axis()
-            ifm_idx = self.inputs.index(ifms[0])
-            if len(ifms) > 1:
-                ifm2_idx = self.inputs.index(ifms[1])
-            ofm_idx = 0
-
-        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx
-
     def get_ifm_ifm2_weights_ofm(self):
-        ifm_tensor = None
-        ifm2_tensor = None
-        weight_tensor = None
-        ofm_tensor = None
-
-        ifm_idx, ifm2_idx, weight_idx, _, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
-        if ifm_idx != -1:
-            ifm_tensor = self.inputs[ifm_idx]
-        if ifm2_idx != -1:
-            ifm2_tensor = self.inputs[ifm2_idx]
-        if weight_idx != -1:
-            weight_tensor = self.inputs[weight_idx]
-        if ofm_idx != -1:
-            ofm_tensor = self.outputs[ofm_idx]
-
-        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor
+        return self.ifm, self.ifm2, self.weights, self.ofm
 
     def get_ifm_weights_biases_ofm(self):
-        ifm_tensor = None
-        weight_tensor = None
-        bias_tensor = None
-        ofm_tensor = None
-
-        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
-        if ifm_idx != -1:
-            ifm_tensor = self.inputs[ifm_idx]
-        if weight_idx != -1:
-            weight_tensor = self.inputs[weight_idx]
-        if bias_idx != -1:
-            bias_tensor = self.inputs[bias_idx]
-        if ofm_idx != -1:
-            ofm_tensor = self.outputs[ofm_idx]
-
-        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor
+        return self.ifm, self.weights, self.bias, self.ofm
 
     def get_ifm_ifm2_weights_biases_ofm(self):
-        ifm_tensor = None
-        ifm2_tensor = None
-        weight_tensor = None
-        bias_tensor = None
-        ofm_tensor = None
+        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm
 
-        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
-        if ifm_idx != -1:
-            ifm_tensor = self.inputs[ifm_idx]
-        if ifm2_idx != -1:
-            ifm2_tensor = self.inputs[ifm2_idx]
-        if weight_idx != -1:
-            weight_tensor = self.inputs[weight_idx]
-        if bias_idx != -1:
-            bias_tensor = self.inputs[bias_idx]
-        if ofm_idx != -1:
-            ofm_tensor = self.outputs[ofm_idx]
+    def get_ifm_ofm(self):
+        return self.ifm, self.ofm
 
-        return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor
+    @property
+    def ifm(self):
+        # Gets the IFM tensor, or None if not applicable
+        return self.get_input(self.type.info.indices.ifms, 0)
 
-    def get_ofm(self):
-        _, _, _, ofm = self.get_ifm_ifm2_weights_ofm()
-        return ofm
+    @property
+    def ifm2(self):
+        # Gets the IFM2 tensor, or None if not applicable
+        return self.get_input(self.type.info.indices.ifms, 1)
 
-    def is_concat_op(self):
-        return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped")
+    @property
+    def bias(self):
+        # Gets the bias tensor, or None if not applicable
+        return self.get_input(self.type.info.indices.biases, 0)
+
+    @property
+    def weights(self):
+        # Gets the weight tensor, or None if not applicable
+        return self.get_input(self.type.info.indices.weights, 0)
+
+    def get_ifm_tensors(self):
+        # Gets the IFM tensors, or empty list if not applicable
+        return self._index_list_to_tensors(self.type.info.indices.ifms)
+
+    def get_weight_tensors(self):
+        # Gets the weight tensors, or empty list if not applicable
+        return self._index_list_to_tensors(self.type.info.indices.weights)
+
+    def get_bias_tensors(self):
+        # Gets the bias tensors, or empty list if not applicable
+        return self._index_list_to_tensors(self.type.info.indices.biases)
+
+    def _index_list_to_tensors(self, index_list):
+        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]
+
+    def get_input(self, index_list, ix):
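+        # Returns the tensor at position index_list[ix] in self.inputs, or None if out of range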
+        if ix >= len(index_list):
+            return None
+        if index_list[ix] >= len(self.inputs):
+            return None
+        return self.inputs[index_list[ix]]
+
+    @property
+    def ofm(self):
+        # Gets the OFM tensor, or None if not applicable
+        return self.outputs[0] if self.outputs else None
 
     def get_concat_inputs_axis(self):
-        assert self.is_concat_op()
+        assert self.type.is_concat_op()
 
-        if self.type == "ConcatV2":
-            axis_tensor = self.inputs[-1]
-            inputs = self.inputs[:-1]
-        elif self.type == "Concat":
+        if self.type == Op.Concat:
             axis_tensor = self.inputs[0]
             inputs = self.inputs[1:]
-        elif self.type == "QuantizedConcat":
-            axis_tensor = self.inputs[0]
-            inputs = self.inputs[1:]
-            inputs = inputs[: len(inputs) // 3]  # Skip min/max
-
-        if self.type == "ConcatTFLite":
+        elif self.type == Op.ConcatTFLite:
             inputs = self.inputs
             axis = self.attrs["axis"]
-        elif self.type == "PackReshaped":
+        elif self.type == Op.PackReshaped:
             # Requires fixup_pack_input to be called before this point
             inputs = self.inputs
             axis = self.attrs["axis"]
             assert len(self.inputs) == self.attrs["values_count"]
         else:
-            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
+            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
             axis = int(axis_tensor.values)
 
         return inputs, axis
@@ -267,33 +456,30 @@
         _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
         return dilation_h, dilation_w
 
-    def is_split_op(self):
-        return self.type in ("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped")
-
     def get_split_inputs_axis(self):
-        assert self.is_split_op()
+        assert self.type.is_split_op()
 
         offset_start = None
         offset_end = None
         axis = None
-        if self.type == "Split":
+        if self.type == Op.Split:
             num_splits = self.attrs.get("num_splits")
             axis_tens = self.inputs[0]
-            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
+            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
             axis = int(axis_tens.values)
             input_tens = self.inputs[1]
             outputs = self.outputs
             assert num_splits == len(outputs)
 
-        elif self.type == "SplitV":
+        elif self.type == Op.SplitV:
             num_splits = self.attrs.get("num_splits")
             input_tens = self.inputs[0]
             size_tens = self.inputs[1]
-            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
+            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
             sizes = size_tens.values
 
             axis_tens = self.inputs[2]
-            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
+            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
             axis = int(axis_tens.values)
 
             for idx, size in enumerate(sizes):
@@ -306,7 +492,7 @@
             assert num_splits == len(outputs)
             assert sum(sizes) == input_tens.shape[axis]
 
-        elif self.type == "Slice":
+        elif self.type == Op.Slice:
             input_tens, begin_tens, size_tens = self.inputs
             outputs = self.outputs
             offset_start = [0] * len(input_tens.shape)
@@ -318,7 +504,7 @@
                     offset_start[idx] = begin_tens.values[idx]
                     offset_end[idx] = size_tens.values[idx] + offset_start[idx]
 
-        elif self.type == "StridedSlice":
+        elif self.type == Op.StridedSlice:
             input_tens, begin_tens, end_tens, strides_tens = self.inputs
             outputs = self.outputs
             out_tens = outputs[0]
@@ -336,7 +522,7 @@
             assert len(input_tens.shape) == len(out_tens.shape)
             offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
             offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
-        elif self.type == "UnpackReshaped":
+        elif self.type == Op.UnpackReshaped:
             # Requires fixup_unpack_output to be called before this point
             input_tens = self.inputs[0]
             outputs = self.outputs
@@ -350,7 +536,7 @@
         return input_tens, outputs, axis, offset_start, offset_end
 
     def set_activation_lut(self, lut_tensor):
-        self.attrs["fused_activation_function"] = "LUT"
+        self.activation = Op.LUT
         self.activation_lut = lut_tensor
         self.add_input_tensor(lut_tensor)
 
@@ -372,13 +558,7 @@
         tens.ops = [self]
         self.outputs = [tens]
 
-    def needs_bias(self):
-        return self.type in (
-            "Conv2DBiasAct",
-            "DepthwiseConv2dBiasAct",
-            "Conv2DBackpropInputSwitchedBias",
-            "FullyConnectedAct",
-        )
-
     def get_output_quantization(self):
-        return self.attrs.get("forced_output_quantization", self.get_ofm().quantization)
+        if self.forced_output_quantization is not None:
+            return self.forced_output_quantization
+        return self.ofm.quantization