MLBEDSW-3148: Refactor Operation

- op.type is now an enum instead of a string
- Removed unused operator codes
- Refactored attributes such as npu_block_type and fused_activation_function
- Refactored operator index calculation
- Refactored a number of operator sets
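
A minimal before/after sketch of the new pattern (illustrative only;
names are taken from this change):

    # Before: op.type was a string, and the block type / fused
    # activation function lived in the op.attrs dictionary.
    if op.type == "Conv2DBiasAct":
        block = op.attrs["npu_block_type"]
        faf = op.attrs.get("fused_activation_function")

    # After: op.type is an Op enum member carrying static info such as
    # the NPU block type; the fused activation is now op.activation.
    if op.type == Op.Conv2DBias:
        block = op.type.npu_block_type
        faf = op.activation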

Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 7391d71..3ef4d1b 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -574,7 +574,7 @@
         Cortex-Mx.intercept=<some float value>
         Cortex-Mx.slope=<some float value>
         """
-        section = "CpuPerformance." + op.type
+        section = "CpuPerformance." + op.type.name
         if self.vela_config is not None and section in self.vela_config:
             op_config = self.vela_config[section]
             try:
diff --git a/ethosu/vela/extract_npu_subgraphs.py b/ethosu/vela/extract_npu_subgraphs.py
index c0430b5..e08392d 100644
--- a/ethosu/vela/extract_npu_subgraphs.py
+++ b/ethosu/vela/extract_npu_subgraphs.py
@@ -25,17 +25,19 @@
 from .nn_graph import Pass
 from .nn_graph import PassPlacement
 from .nn_graph import Subgraph
+from .operation import CustomType
 from .operation import NpuBlockType
+from .operation import Op
 from .operation import Operation
 
 
 def make_npu_call_op_pass(npu_subgraph):
-    op = Operation("NpuOp", "call_" + npu_subgraph.name)
+    op = Operation(Op.CustomNpuOp, "call_" + npu_subgraph.name)
     op.attrs["subgraph"] = npu_subgraph
+    op.attrs["custom_type"] = CustomType.NpuOp
     ps = Pass(op.name, PassPlacement.MemoryOnly, False, NpuBlockType.Default)
     ps.ops = [op]
     ps.primary_op = op
-    op.attrs["npu_block_type"] = ps.npu_block_type
     op.scheduled_pass = ps
 
     # Inputs and outputs filled in later as we cut the graphs
@@ -69,14 +71,13 @@
 def rewrite_tensor_cpu_producer_npu_consumers(
     orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
 ):
-    is_const = orig_tens.ops[0].type == "Const"
+    is_const = orig_tens.ops[0].type == Op.Const
     new_tens = orig_tens.clone("_npu")
 
-    op_type = "SubgraphInput"
+    op_type = Op.SubgraphInput
     if is_const:
-        op_type = "Const"
+        op_type = Op.Const
     op = Operation(op_type, orig_tens.name + "_input")
-    op.attrs["npu_block_type"] = NpuBlockType.Default
     op.scheduled_pass = startup_init_ps
     op.set_output_tensor(new_tens)
     startup_init_ps.ops.append(op)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 4f435dc..1966a82 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -32,6 +32,7 @@
 from .numeric_util import round_away_zero
 from .operation import create_avgpool_nop
 from .operation import NpuBlockType
+from .operation import Op
 from .operation import Operation
 from .softmax import SoftMax
 from .tensor import create_const_tensor
@@ -39,33 +40,9 @@
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 
-passthrough_nodes = set(("Identity",))
+passthrough_nodes = set((Op.Identity,))
 
-conv_op = set(("Conv2D", "QuantizedConv2D", "Conv2DBackpropInputSwitchedBias", "Conv2DBiasAct"))
-fc_op = set(
-    (
-        "MatMul",
-        "QuantizedMatMul",
-        "BlockLSTM",
-        "RnnAct",
-        "UnidirectionalSequenceRnnAct",
-        "BidirectionalSequenceRnnAct",
-        "LstmAct",
-        "UnidirectionalSequenceLstmAct",
-        "BidirectionalSequenceLstmAct",
-        "FullyConnectedAct",
-    )
-)
-depthwise_op = set(("DepthwiseConv2dNative", "DepthwiseConv2dBiasAct",))
-pool_op = set(
-    ("AvgPool", "MaxPool", "QuantizedAvgPool", "QuantizedMaxPool", "AvgPoolAct", "MaxPoolAct", "ResizeBilinear")
-)
-reduce_sum_ops = set(("ReduceSum",))
-binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum"))
-elementwise_op = set(("LeakyRelu", "Abs", "CLZ", "SHL", "SHR")) | binary_elementwise_op
-relu_ops = set(("Relu", "Relu6", "ReluN1To1"))
-activation_ops = set(("Sigmoid", "Tanh")) | relu_ops
-memory_only_ops = set(("Reshape",))
+memory_only_ops = set((Op.Reshape,))
 
 
 def remove_passthrough_tensor(tens, arch, nng):
@@ -76,7 +53,7 @@
 
 
 def rewrite_concat(tens, arch, nng):
-    if len(tens.ops) == 1 and tens.ops[0].is_concat_op():
+    if len(tens.ops) == 1 and tens.ops[0].type.is_concat_op():
         concat_op = tens.ops[0]
         if tens != concat_op.outputs[0]:
             return tens  # don't attempt to rewrite the min/max outputs of QuantizedConcat
@@ -90,7 +67,7 @@
         tens.ops = []
         offset = 0
         for idx, inp in enumerate(inputs):
-            new_op = Operation("ConcatSliceWrite", concat_op.name + str(idx))
+            new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
             new_op.inputs = [inp]
             new_op.outputs = [tens]
             new_op.attrs["concat_axis"] = axis
@@ -116,7 +93,7 @@
 
 def rewrite_split(tens, arch, nng):
 
-    if len(tens.ops) == 1 and tens.ops[0].is_split_op():
+    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op():
         split_op = tens.ops[0]
 
         # Not supported so leave it and run on CPU
@@ -126,7 +103,7 @@
         inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
 
         tens.ops = []
-        new_op = Operation("SplitSliceRead", split_op.name)
+        new_op = Operation(Op.SplitSliceRead, split_op.name)
         new_op.inputs = [inp]
 
         # For Split the offset cannot be extracted from the tensor so it has to
@@ -206,10 +183,10 @@
 
 
 def fixup_conv2d_backprop(op, arch, nng):
-    if op.type == "Conv2DBackpropInput":
+    if op.type == Op.Conv2DBackpropInput:
         # flip the inputs
         op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
-        op.type = "Conv2DBackpropInputSwitchedBias"
+        op.type = Op.Conv2DBackpropInputSwitchedBias
 
         # Update strides
         op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
@@ -219,9 +196,8 @@
 
 # Convert the op to an elementwise add
 def convert_resizebilinear_1x1_to_add(op):
-    op.type = "AddAct"
+    op.type = Op.Add
     op.name = op.name + "_add"
-    op.attrs.update({"npu_block_type": NpuBlockType.ElementWise})
     op.attrs["resizebilinear"] = True
     # Create an input tensor filled with zeros
     shape = op.outputs[0].shape
@@ -296,11 +272,11 @@
 
 
 def fixup_resizebilinear(op, arch, nng):
-    if op.type == "ResizeBilinear" and op.run_on_npu:
+    if op.type == Op.ResizeBilinear and op.run_on_npu:
         if op.inputs[0].shape == op.outputs[0].shape:
             # Bypass nop resizebilinear
             op.inputs = op.inputs[:1]
-            op.type = "Identity"
+            op.type = Op.Identity
         elif op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
             convert_resizebilinear_1x1_to_add(op)
         else:
@@ -310,16 +286,16 @@
 
 
 def convert_nop_split_to_identity(op, arch, nng):
-    if op.type == "Split" and op.attrs.get("num_splits") == 1:
+    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
         # the list comprehension should return a list with a single tensor
         # if it shouldn't, remove_passthrough_tensor will fail appropriately
         op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
-        op.type = "Identity"
+        op.type = Op.Identity
     return op
 
 
 def fixup_fully_connected_input(op, arch, nng):
-    if op.type == "FullyConnectedAct":
+    if op.type == Op.FullyConnected:
         inp = op.inputs[0]
         weights = op.inputs[1]
 
@@ -337,7 +313,7 @@
 
 
 def convert_batched_fc_to_conv(op, arch, nng):
-    if op.type == "FullyConnectedAct":
+    if op.type == Op.FullyConnected:
         ifm = op.inputs[0]
         ofm = op.outputs[0]
         # Check if the FC is 2D and first dimension indicates batching
@@ -348,14 +324,11 @@
 
             # Convert to convolution
             op.name += "_conv"
-            op.type = "Conv2DBiasAct"
-            faf = op.attrs.get("fused_activation_function", None)
+            op.type = Op.Conv2DBias
             op.attrs = {
                 "dilation": (1, 1, 1, 1),
                 "dilation_h_factor": 1,
                 "dilation_w_factor": 1,
-                "fused_activation_function": faf,
-                "npu_block_type": NpuBlockType.ConvolutionMxN,
                 "padding": b"SAME",
                 "stride_h": 1,
                 "stride_w": 1,
@@ -364,7 +337,7 @@
 
             prev_op = ifm.ops[0]
             desired_shape = [1, h, w, ifm.shape[-1]]
-            if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == "Reshape":
+            if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
                 # There is a preceding Reshape
                 # Compare input of prev_op and input of op, to see if prev_op can be removed
                 ifm_prev_op = prev_op.inputs[0]
@@ -391,7 +364,7 @@
             if (
                 len(ofm.consumer_list) == 1
                 and ofm.consumer_list[0] is not None
-                and ofm.consumer_list[0].type == "Reshape"
+                and ofm.consumer_list[0].type == Op.Reshape
             ):
                 # There is a subsequent Reshape
                 # Compare desired shape and output of consumer op, to see if consumer op can be removed
@@ -408,7 +381,7 @@
 
 
 def fixup_pack_input(op, arch, nng):
-    if op.type == "Pack":
+    if op.type == Op.Pack:
         # Pack is also referred to as Stack
         # Requires the rewrite_concat function to be called on the op afterwards
         axis = int(op.attrs["axis"])
@@ -421,24 +394,22 @@
             reshape_out = inp.clone("_reshaped")
             reshape_out.set_all_shapes(desired_shape)
 
-            reshape_op = Operation("Reshape", "{}{}_reshape".format(op.name, idx))
+            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
             reshape_op.attrs["new_shape"] = desired_shape
             reshape_op.inputs = [inp, new_shape_tens]
             reshape_op.set_output_tensor(reshape_out)
 
             op.inputs[idx] = reshape_out
 
-        op.type = "PackReshaped"
+        op.type = Op.PackReshaped
 
     return op
 
 
 def unfuse_activation_function(op, arch, nng):
-    unfuse_ops = ("ConcatTFLite",)
-    if op.type in unfuse_ops and op.run_on_npu and op.attrs.get("fused_activation_function", None) is not None:
-        act = op.attrs["fused_activation_function"]
-        del op.attrs["fused_activation_function"]
-        act_op = Operation(act, op.name + act)
+    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
+        act_op = Operation(op.activation, op.name + op.activation.name)
+        op.activation = None
         out_tens = op.outputs[0]
         intermediate_tens = out_tens.clone("_act_intermediate")
         act_op.set_output_tensor(out_tens)
@@ -450,12 +421,12 @@
 
 def fixup_unpack_output(tens, arch, nng):
     op = tens.ops[0]
-    if op.type in set(("Unpack", "StridedSlice")):
+    if op.type in set((Op.Unpack, Op.StridedSlice)):
         # Unpack is also referred to as Unstack
         # Requires the rewrite_split function to be called on the op afterwards
 
         reshape_input_shape = tens.shape
-        if op.type == "StridedSlice":
+        if op.type == Op.StridedSlice:
             new_axis_mask = op.attrs["new_axis_mask"]
             shrink_axis_mask = op.attrs["shrink_axis_mask"]
             ellipsis_mask = op.attrs["ellipsis_mask"]
@@ -494,7 +465,7 @@
                 op.attrs["new_axis_mask"] = 0
         else:
             axis = int(op.attrs["axis"])
-            op.type = "UnpackReshaped"
+            op.type = Op.UnpackReshaped
             reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
 
         # Construct 1 shape tensor to be used by all inserted reshape ops
@@ -505,7 +476,7 @@
             reshape_in.set_all_shapes(reshape_input_shape)
             reshape_in.ops = [op]
 
-            reshape_op = Operation("Reshape", "{}{}_reshape".format(op.name, idx))
+            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
             reshape_op.attrs["new_shape"] = reshape_input_shape
             reshape_op.inputs = [reshape_in, new_shape_tens]
             reshape_op.set_output_tensor(out_tens)
@@ -518,19 +489,16 @@
 def add_padding_fields(op, arch, nng):
     if op.run_on_npu:
         if "padding" in op.attrs:
-            if op.type in conv_op | depthwise_op:
+            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                 kernel_size = op.inputs[1].shape[:2]
                 input_shape = op.inputs[0].shape
-            elif op.type in pool_op | reduce_sum_ops:
+            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                 kernel_size = op.attrs["ksize"][1:3]
                 input_shape = op.inputs[0].shape
-            elif op.type == "ExtractImagePatches":
-                kernel_size = op.attrs["ksizes"][1:3]
-                input_shape = op.inputs[0].shape
             else:
                 raise UnsupportedFeatureError("Unknown operation that uses padding: {}".format(op.type))
 
-            if op.type == "Conv2DBackpropInputSwitchedBias":
+            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                 upscaling_factor = op.outputs[0].shape[1] // input_shape[1]
                 padding, skirt = calc_upscaled_padding_and_skirt(
                     op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
@@ -564,38 +532,19 @@
     return None
 
 
-def mark_npu_block_type(op, arch, nng):
-    npu_block_type = NpuBlockType.Default
-    if op.type in conv_op:
-        npu_block_type = NpuBlockType.ConvolutionMxN
-    elif op.type in fc_op:
-        npu_block_type = NpuBlockType.VectorProduct
-    elif op.type in depthwise_op:
-        npu_block_type = NpuBlockType.ConvolutionDepthWise
-    elif op.type in pool_op:
-        npu_block_type = NpuBlockType.Pooling
-    elif op.type in elementwise_op:
-        npu_block_type = NpuBlockType.ElementWise
-    elif op.type in reduce_sum_ops:
-        npu_block_type = NpuBlockType.ReduceSum
-
-    op.attrs["npu_block_type"] = npu_block_type
-    return op
-
-
 def convert_depthwise_to_conv(op, arch, nng):
     # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
     # the ofm depth equals the depth multiplier.
     # If those conditions are true, then we can perform a simple
     # switch of the operator type (and weight order)
 
-    if (op.type in depthwise_op) and (op.attrs["depth_multiplier"] != 1):
+    if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
         ifm_tensor = op.inputs[0]
         weight_tensor = op.inputs[1]
         ofm_tensor = op.outputs[0]
         if (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"]):
             # Change op type to Conv2d
-            op.type = op.type.replace("DepthwiseConv2d", "Conv2D")
+            op.type = Op.Conv2DBias
             del op.attrs["channel_multiplier"]
             del op.attrs["depth_multiplier"]
 
@@ -611,7 +560,7 @@
 
 
 def reorder_depthwise_weights(op, arch, nng):
-    if op.type in depthwise_op:
+    if op.type.is_depthwise_conv2d_op():
         weight_tensor = op.inputs[1]
         weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
         weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
@@ -625,18 +574,15 @@
     # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
     # caching/double buffering for the weights.
     # (Weights don't need to be reloaded for convs when IFM H and W are 1)
-    if op.type == "Conv2DBiasAct":
+    if op.type == Op.Conv2DBias:
         _, h, w, _ = op.inputs[0].shape
         kh, kw, _, _ = op.inputs[1].shape
         if h == 1 and w == 1 and kh == 1 and kw == 1:
             # Overwrite this op as a Fully Connected Op
             op.name += "_fc"
-            op.type = "FullyConnectedAct"
-            faf = op.attrs.get("fused_activation_function", None)
+            op.type = Op.FullyConnected
             op.attrs = {
-                "fused_activation_function": faf,
                 "weights_format": 0,
-                "npu_block_type": NpuBlockType.VectorProduct,
             }
             # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
             weight_tensor = op.inputs[1]
@@ -652,7 +598,7 @@
             # Add a reshape after the new OFM to convert it back to the original 4D shape
             reshape_name = op.name + "_reshape"
             new_shape_tens = create_const_tensor(reshape_name + "_shape", [1], DataType.int32, orig_ofm_tensor.shape)
-            reshape_op = Operation("Reshape", reshape_name)
+            reshape_op = Operation(Op.Reshape, reshape_name)
             reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
             reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
             reshape_op.set_output_tensor(orig_ofm_tensor)
@@ -662,7 +608,7 @@
 
 
 def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
-    if op.run_on_npu and op.type in relu_ops:
+    if op.run_on_npu and op.type.is_relu_op():
         ifm = op.inputs[0]
         ofm = op.outputs[0]
         # Relu with differing IFM and OFM scaling cannot be fused with another primary op
@@ -671,7 +617,7 @@
             # Override this op with its own primary op (avgpool)
             relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
             # And fuse the original activation function to it
-            relu_fused_op.attrs["fused_activation_function"] = op.type
+            relu_fused_op.activation = op.type
             # Tidy up and assign the ifm and ofm to the new op
             ifm.consumer_list.remove(op)
 
@@ -691,7 +637,7 @@
 
 # Reorder activation op if it's after the memory only operations
 def fixup_act_reorder(op, arch, nng):
-    if op.type in activation_ops:
+    if op.type.is_relu_op() or op.type in (Op.Sigmoid, Op.Tanh):
         prep_op = get_prepend_op(op)
         if prep_op is not None:
             act_op = op.clone("_reordered")
@@ -711,12 +657,12 @@
             prep_op.outputs[0].quantization = act_op_out.quantization.clone()
 
             # Mark the op so that it will be removed as passthrough later on
-            op.type = "Identity"
+            op.type = Op.Identity
     return op
 
 
 def fixup_elementwise_with_scalars(op, arch, nng):
-    if op.type in binary_elementwise_op:
+    if op.type.is_binary_elementwise_op():
         ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
         if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
             diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
@@ -745,7 +691,7 @@
 
 
 def convert_softmax(op, arch, nng):
-    if op.type == "Softmax" and op.run_on_npu:
+    if op.type == Op.Softmax and op.run_on_npu:
         softmax = SoftMax(op)
         op = softmax.get_graph()
     return op
@@ -761,9 +707,9 @@
        Max
     """
 
-    if op.type == "Maximum":
+    if op.type == Op.Maximum:
         # finds the Mul input(s) to the Max
-        muls = [i for i in op.inputs if i.ops[0].type == "MulAct"]
+        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
         if len(muls) == 1:
             mul = muls[0].ops[0]
         elif len(muls) == 2:
@@ -777,10 +723,10 @@
         mul_ofm = mul.outputs[0]
         if len(mul_ofm.consumers()) != 1:
             return op
-        # make sure the Mul doesn't have a faf
-        if mul.attrs["fused_activation_function"]:
+        # make sure the Mul doesn't have a fused activation function
+        if mul.activation:
             return op
-        ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+        ifm, ofm = op.get_ifm_ofm()
         if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
             return op
         if not ifm.is_scaling_equal(ofm) or not ifm.is_scaling_equal(mul_ofm):
@@ -798,7 +744,7 @@
                 return op
             const = const_tens.ops[0]
             # check that it is a constant
-            if const.type != "Const":
+            if const.type != Op.Const:
                 return op
             # Remove the Mul from the shared input's consumers
             shared_in.consumer_list.remove(mul)
@@ -807,7 +753,7 @@
 
         val = const.outputs[0].values
         if val >= 0:
-            new_op = "LeakyRelu"
+            new_op = Op.LeakyRelu
             op.attrs["alpha"] = val
             # to produce bit exact results, the alpha is not enough;
             # save additional scaling info in attr "alpha_scale", to be used as input
@@ -819,13 +765,13 @@
             alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
             op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
         elif val == -1:
-            new_op = "Abs"
+            new_op = Op.Abs
         else:
             return op
 
-        op.type = op.type.replace("Maximum", new_op)
-        op.name = op.name.replace("Maximum", new_op)
-        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op)
+        op.type = new_op
+        op.name = op.name.replace("Maximum", new_op.name)
+        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
         op.inputs = [shared_in]
     return op
 
@@ -833,10 +779,10 @@
 def convert_lrelu_to_mul_max(op, arch):
     # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
     # (the opposite of convert_mul_max_to_abs_or_lrelu)
-    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+    ifm, ofm = op.get_ifm_ofm()
 
     # Add multiplication with alpha
-    mul_alpha = Operation("MulAct", op.name + "_mul_alpha")
+    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
     mul_alpha.add_input_tensor(ifm)
     # Create const tensor containing alpha as scalar
     alpha = op.attrs["alpha"]
@@ -855,7 +801,7 @@
         fm_id = ifm
     else:
         # Add multiplication with identity
-        mul_identity = Operation("MulAct", op.name + "_mul_identity")
+        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
         mul_identity.add_input_tensor(ifm)
         # Create const tensor containing identity as scalar
         quantization = ifm.quantization.clone()
@@ -871,7 +817,7 @@
         mul_identity.set_output_tensor(fm_id)
 
     # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
-    op.type = "Maximum"
+    op.type = Op.Maximum
     op.name = op.name.replace("LeakyRelu", "Maximum")
     op.inputs = []
     ifm.consumer_list.remove(op)
@@ -884,9 +830,8 @@
     # Rewrite the operation as an Add with scalar 0 + LUT activation
     ifm = op.inputs[0]
     assert ifm.dtype.size_in_bytes() == 1
-    op.type = "AddAct"
+    op.type = Op.Add
     op.name = op.name + "_add"
-    op.attrs.update({"npu_block_type": NpuBlockType.ElementWise})
     # Mark as no-op to enable potential fusing optimizations
     op.attrs["is_nop"] = True
     # Create an input tensor containing scalar zero
@@ -898,7 +843,7 @@
     # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
     # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
     # should be the same as the IFM
-    op.attrs["forced_output_quantization"] = ifm.quantization
+    op.forced_output_quantization = ifm.quantization
     lut_tensor = lut.create_lut_tensor(op.name + "_lut", lut_values, DataType.int8)
     op.set_activation_lut(lut_tensor)
     return op
@@ -907,7 +852,7 @@
 def convert_to_lut8(op, fn):
     # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
     # fn is a function(real) -> real
-    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+    ifm, ofm = op.get_ifm_ofm()
     if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
         return op
     # Generate the LUT
@@ -929,7 +874,7 @@
 
 
 def convert_lrelu_to_lut(op, arch):
-    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+    ifm, ofm = op.get_ifm_ofm()
     # Generate the LUT
     alpha = op.attrs["alpha"]
     ifm_scale = np.double(ifm.quantization.scale_f32)
@@ -960,9 +905,9 @@
 
 def convert_lrelu(op, arch, nng):
     # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
-    if op.type != "LeakyRelu":
+    if op.type != Op.LeakyRelu:
         return op
-    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+    ifm, ofm = op.get_ifm_ofm()
     if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
         # use LUT for int8/uint8
         return convert_lrelu_to_lut(op, arch)
@@ -974,20 +919,20 @@
 
 def convert_tanh_sigmoid_to_lut(op, arch, nng):
     # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
-    if op.type == "Sigmoid":
+    if op.type == Op.Sigmoid:
         return convert_to_lut8(op, clamp_sigmoid)
-    elif op.type == "Tanh":
+    elif op.type == Op.Tanh:
         return convert_to_lut8(op, math.tanh)
     return op
 
 
 def remove_unwanted_reshapes(op, arch, nng):
     # Try to remove reshapes enclosing an ElementWise operator with only one non-constant input
-    if not op.run_on_npu or op.attrs["npu_block_type"] != NpuBlockType.ElementWise:
+    if not op.run_on_npu or not op.type.is_elementwise_op():
         return op
 
     # Check if the ElementWise operator has only one non-constant input
-    non_const_tens = [x for x in op.inputs if x.ops[0].type != "Const"]
+    non_const_tens = [x for x in op.inputs if x.ops[0].type != Op.Const]
     if len(non_const_tens) != 1:
         return op
     ifm = non_const_tens[0]
@@ -997,12 +942,12 @@
     prev_op = ifm.ops[0]
     if (
         len(ifm.consumer_list) == 1
-        and prev_op.type == "Reshape"
+        and prev_op.type == Op.Reshape
         and len(ofm.consumer_list) == 1
-        and ofm.consumer_list[0].type == "Reshape"
+        and ofm.consumer_list[0].type == Op.Reshape
     ):
         # Operation is enclosed by reshapes, check if they can be removed
-        prev_op_ifm, _, _, prev_op_ofm = prev_op.get_ifm_weights_biases_ofm()
+        prev_op_ifm, prev_op_ofm = prev_op.get_ifm_ofm()
         cons_op = ofm.consumer_list[0]
         cons_op_ifm = ofm
         cons_op_ofm = cons_op.outputs[0]
@@ -1018,19 +963,18 @@
 
 def fuse_activation_function_with_prev(op, arch, nng):
     # if op is a no-op: attempts to move the activation function to the preceding op
-    if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None:
+    if not op.attrs.get("is_nop", False) or op.activation is None:
         return op
-    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+    ifm, ofm = op.get_ifm_ofm()
     # finds the input(s) to the operation
     prev_op = ifm.ops[0]
     # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
     fuse = (
         prev_op.run_on_npu
-        and "npu_block_type" in prev_op.attrs
-        and prev_op.attrs["npu_block_type"] != NpuBlockType.Default
+        and prev_op.type.npu_block_type != NpuBlockType.Default
         and len(ifm.ops) == 1
         and len(prev_op.outputs[0].consumers()) == 1
-        and prev_op.attrs.get("fused_activation_function", None) is None
+        and prev_op.activation is None
     )
     if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
         # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
@@ -1039,9 +983,8 @@
     if not fuse:
         return op
     # Move the fused activation function + corresponding info to prev_op
-    for attr in ("fused_activation_function", "forced_output_quantization"):
-        if attr in op.attrs:
-            prev_op.attrs[attr] = op.attrs[attr]
+    prev_op.activation = op.activation
+    prev_op.forced_output_quantization = op.forced_output_quantization
     if op.activation_lut is not None:
         prev_op.set_activation_lut(op.activation_lut)
     # Bypass op
@@ -1050,7 +993,7 @@
 
 
 def add_attrs_to_resizebilinear(op, arch, nng):
-    if op.type == "ResizeBilinear" and op.run_on_npu:
+    if op.type == Op.ResizeBilinear and op.run_on_npu:
         input_tensor = op.inputs[0]
         upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2]
         out_shape = op.outputs[0].shape[1:3]
@@ -1070,7 +1013,7 @@
 
 
 def fixup_bias_tensors(op, arch, nng):
-    if op.needs_bias() and not op.inputs[-1]:
+    if op.type.needs_bias() and op.bias is None:
         # Op has no bias, add bias tensor filled with zeros
         nr_biases = op.inputs[1].shape[-1]
         bias_values = [0] * nr_biases
@@ -1091,8 +1034,6 @@
         nng.print_graph()
 
     op_rewrite_list = [
-        # mark block type and check if the operations are supported
-        mark_npu_block_type,
         set_tensor_equivalence,
         supported_operator_check,
         # then do any rewrites of supported operators
@@ -1106,7 +1047,6 @@
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
         fixup_act_reorder,
-        mark_npu_block_type,
         fixup_elementwise_with_scalars,
         reorder_depthwise_weights,
         fixup_resizebilinear,
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 8486dad..dc52ae5 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -25,6 +25,7 @@
 from .nn_graph import SchedulingStrategy
 from .numeric_util import round_up_divide
 from .operation import NpuBlockType
+from .operation import Op
 from .tensor import TensorPurpose
 
 
@@ -39,7 +40,7 @@
     if source == derived:
         return True
     ops = derived.ops
-    return ops != [] and len(ops) == 1 and ops[0].type == "SplitSliceRead" and source == ops[0].inputs[0]
+    return ops != [] and len(ops) == 1 and ops[0].type == Op.SplitSliceRead and source == ops[0].inputs[0]
 
 
 def generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx):
@@ -56,8 +57,8 @@
             ps.ifm_tensor, ps.ifm2_tensor = ps.ifm2_tensor, ps.ifm_tensor
 
         for op in ps.ops:
-            if op.type == "SplitSliceRead":
-                ps.primary_op.attrs["fused_memory_function"] = op.type
+            if op.type == Op.SplitSliceRead:
+                ps.primary_op.memory_function = op.type
                 assert len(op.inputs) == 1
                 if match_tensor(ps.ifm_tensor, op.inputs[0]):
                     split_offsets[0] = op.attrs["split_start"]
@@ -68,10 +69,10 @@
     else:
         ifm_idx = 0
         for op in ps.ops:
-            if op.type == "SplitSliceRead":
+            if op.type == Op.SplitSliceRead:
                 assert ifm_idx < 2
                 split_offsets[ifm_idx] = op.attrs["split_start"]
-                ps.primary_op.attrs["fused_memory_function"] = op.type
+                ps.primary_op.memory_function = op.type
                 ifm_idx += 1
 
     ifm_tensor = ps.ifm_tensor
@@ -89,19 +90,16 @@
     if ps.primary_op is not None:
         strides = ps.primary_op.attrs.get("strides", None)
         skirt = ps.primary_op.attrs.get("skirt", None)
-        if ps.primary_op.type == "Conv2DBackpropInputSwitchedBias":
+        if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
             upscaling = ofm_tensor.shape[-3] // ifm_tensor.shape[-3]
-        elif ps.primary_op.type == "ResizeBilinear":
+        elif ps.primary_op.type == Op.ResizeBilinear:
             upscaling = round_up_divide(ofm_tensor.shape[-3], ifm_tensor.shape[-3])
 
     concat_axis = 0
     concat_offset = 0
 
-    # Fusable activation functions
-    activation_ops = set(("Sigmoid", "Tanh", "Relu", "Relu6", "ReluN1To1"))
-
     for op in ps.ops:
-        if op.type == "ConcatSliceWrite":
+        if op.type == Op.ConcatSliceWrite:
             concat_axis = op.attrs["concat_axis"]
             concat_start = op.attrs["concat_start"]
             concat_end = op.attrs["concat_end"]
@@ -109,9 +107,9 @@
             ofm_start[concat_axis] = concat_start
             ofm_end[concat_axis] = concat_end
             concat_offset = concat_start
-            ps.primary_op.attrs["fused_memory_function"] = op.type
-        elif op.type in activation_ops:
-            ps.primary_op.attrs["fused_activation_function"] = op.type
+            ps.primary_op.memory_function = op.type
+        elif op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid):
+            ps.primary_op.activation = op.type
 
     if strat == SchedulingStrategy.WeightStream:
         ofm_step = block_config[-1]
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
index 99b46c0..56d68d1 100644
--- a/ethosu/vela/insert_dma.py
+++ b/ethosu/vela/insert_dma.py
@@ -17,6 +17,7 @@
 # Insert DMA operations into the graph for transferring weights.
 from . import rewrite_graph
 from .operation import NpuBlockType
+from .operation import Op
 from .operation import Operation
 from .tensor import MemArea
 from .tensor import MemType
@@ -24,9 +25,6 @@
 from .weight_compressor import compress_weights
 
 
-binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum"))
-
-
 def weights_fit_sram(arch, op, tens, nng):
     if tens.purpose != TensorPurpose.Weights:
         return True
@@ -57,7 +55,7 @@
 
 
 def insert_dma_cmd(op, arch, nng):
-    if op.type == "DMA" or not op.run_on_npu:
+    if op.type == Op.DMA or not op.run_on_npu:
         return op
 
     is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in op.inputs)
@@ -76,14 +74,14 @@
             ) or tens.purpose == TensorPurpose.LUT:
                 if tens.purpose in (TensorPurpose.Weights, TensorPurpose.LUT) or (
                     tens.purpose == TensorPurpose.FeatureMap
-                    and op.type in binary_elementwise_op
+                    and op.type.is_binary_elementwise_op()
                     and tens.shape != []
                     and tens.shape != op.outputs[0].shape
                     and tens.storage_size() > max_ifm_shram_avail
                 ):
                     only_vector_product_consumers = True
                     for oper in tens.consumers():
-                        if oper is None or oper.attrs.get("npu_block_type") != NpuBlockType.VectorProduct:
+                        if oper is None or oper.type.npu_block_type != NpuBlockType.VectorProduct:
                             only_vector_product_consumers = False
                             break
 
@@ -95,7 +93,7 @@
                     ) or tens.purpose == TensorPurpose.LUT:
                         # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size.
                         new_tens = tens.clone_into_fast_storage(arch)
-                        dma_cmd = Operation("DMA", tens.ops[0].name + "_dma")
+                        dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma")
                         dma_cmd.inputs = [tens]
                         dma_cmd.set_output_tensor(new_tens)
                         dma_cmd.attrs["source"] = tens.mem_area
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 9a8ee58..23026c7 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -18,6 +18,7 @@
 # Can work with either a pass packed subgraph or a scheduled subgraph.
 from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_cascaded_pass
 from .nn_graph import PassPlacement
+from .operation import Op
 from .tensor import MemType
 from .tensor import Tensor
 
@@ -262,7 +263,11 @@
 
         cps_primary_op = cps.passes[0].primary_op
 
-        if cps_primary_op and cps_primary_op.type == "NpuOp" and MemType.Permanent_CPU not in target_mem_type_set:
+        if (
+            cps_primary_op
+            and cps_primary_op.type == Op.CustomNpuOp
+            and MemType.Permanent_CPU not in target_mem_type_set
+        ):
             # If the primary-op is an NpuOp that means this is where an Npu subgraph
             # is called. Go into said subgraph and extract live ranges before continuing.
             # Use default allocation alignment of 16 for Npu tensors
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index c4496cd..206d836 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -18,10 +18,11 @@
 from . import rewrite_graph
 from . import weight_compressor
 from .errors import OperatorError
+from .operation import CustomType
+from .operation import Op
 from .tensor import MemType
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
-from .tflite_mapping import custom_prefix
 
 
 def purpose_from_list(lst):
@@ -59,72 +60,53 @@
     (
         set(
             (
-                "Relu",
-                "Relu6",
-                "Mul",
-                "Add",
-                "Sub",
-                "Rsqrt",
-                "Abs",
-                "Cast",
-                "Exp",
-                "Floor",
-                "FloorDiv",
-                "FloorMod",
-                "SquaredDifference",
-                "AddN",
-                "BiasAdd",
-                "RealDiv",
-                "Maximum",
-                "Minimum",
-                "Sigmoid",
-                "Tanh",
-                "FusedBatchNorm",
-                "AvgPool",
-                "MaxPool",
-                "Squeeze",
-                "Softmax",
-                "LRN",
-                "Assign",
-                "BatchMatMul",
-                "ZerosLike",
-                "ExtractImagePatches",
-                "MulAct",
-                "AddAct",
-                "SubAct",
-                "DivAct",
-                "AvgPoolAct",
-                "MaxPoolAct",
-                "LeakyRelu",
-                "CLZ",
-                "SHL",
-                "SHR",
-                "ReduceSum",
+                Op.Relu,
+                Op.Relu6,
+                Op.Rsqrt,
+                Op.Abs,
+                Op.Cast,
+                Op.Exp,
+                Op.Floor,
+                Op.FloorDiv,
+                Op.FloorMod,
+                Op.SquaredDifference,
+                Op.AddN,
+                Op.Maximum,
+                Op.Minimum,
+                Op.Sigmoid,
+                Op.Tanh,
+                Op.AvgPool,
+                Op.MaxPool,
+                Op.Squeeze,
+                Op.Softmax,
+                Op.LRN,
+                Op.BatchMatMul,
+                Op.ZerosLike,
+                Op.Mul,
+                Op.Add,
+                Op.Sub,
+                Op.Div,
+                Op.LeakyRelu,
+                Op.CLZ,
+                Op.SHL,
+                Op.SHR,
+                Op.ReduceSum,
             )
         ),
         all_fm,
     ),
     (
-        set(
-            (
-                "Conv2D",
-                "DepthwiseConv2dNative",
-                "MatMul",
-                "Conv2DBiasAct",
-                "DepthwiseConv2dBiasAct",
-                "FullyConnectedAct",
-            )
-        ),
+        set((Op.Conv2D, Op.MatMul, Op.Conv2DBias, Op.DepthwiseConv2DBias, Op.FullyConnected,)),
         purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
     ),
     (
-        set(("Conv2DBackpropInputSwitchedBias",)),
+        set((Op.Conv2DBackpropInputSwitchedBias,)),
         purpose_from_list(
             [TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
         ),
     ),
     (
-        set(("QuantizedConv2D", "QuantizedMatMul")),
+        set((Op.QuantizedConv2D, Op.QuantizedMatMul)),
         purpose_from_list(
             [
                 TensorPurpose.FeatureMap,
@@ -139,66 +121,39 @@
     (
         set(
             (
-                "Reshape",
-                "Min",
-                "Max",
-                "Mean",
-                "Pad",
-                "MirrorPad",
-                "ArgMax",
-                "ArgMin",
-                "ExpandDims",
-                "ResizeNearestNeighbor",
-                "ResizeBilinear",
-                "Tile",
-                "Transpose",
-                "Mfcc",
+                Op.Reshape,
+                Op.Min,
+                Op.Max,
+                Op.Mean,
+                Op.Pad,
+                Op.MirrorPad,
+                Op.ArgMax,
+                Op.ArgMin,
+                Op.ExpandDims,
+                Op.ResizeNearestNeighbor,
+                Op.ResizeBilinear,
+                Op.Tile,
+                Op.Transpose,
             )
         ),
         purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
     ),
     (
-        set(("QuantizedReshape", "QuantizedResizeBilinear")),
+        set((Op.QuantizedReshape,)),
         purpose_from_list(
             [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
         ),
     ),
     (
-        set(("QuantizedBiasAdd", "QuantizedAdd", "QuantizedMul")),
-        purpose_from_list(
-            [
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-            ]
-        ),
-    ),
-    (
-        set(
-            (
-                "Dequantize",
-                "Quantize",
-                "QuantizeV2",
-                "QuantizedRelu",
-                "QuantizedRelu1",
-                "QuantizedRelu6",
-                "QuantizedAvgPool",
-                "QuantizedMaxPool",
-                "Slice",
-                "SplitV",
-            )
-        ),
+        set((Op.Dequantize, Op.Quantize, Op.QuantizedAvgPool, Op.QuantizedMaxPool, Op.Slice, Op.SplitV,)),
         purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
     ),
     (
-        set(("BatchToSpaceND", "SpaceToBatchND", "DepthToSpaceND", "SpaceToDepthND")),
+        set((Op.BatchToSpaceND, Op.SpaceToBatchND, Op.DepthToSpace, Op.SpaceToDepth)),
         purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
     ),
     (
-        set(("BlockLSTM",)),
+        set((Op.BlockLSTM,)),
         purpose_from_list(
             [
                 TensorPurpose.FeatureMap,
@@ -213,33 +168,18 @@
             ]
         ),
     ),
-    (set(("SplitSliceRead",)), purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
-    (set(("Shape", "ConcatSliceWrite", "AudioSpectrogram")), purpose_from_list([TensorPurpose.FeatureMap])),
+    (set((Op.SplitSliceRead,)), purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
+    (set((Op.Shape, Op.ConcatSliceWrite)), purpose_from_list([TensorPurpose.FeatureMap])),
     (
-        set(("StridedSlice",)),
+        set((Op.StridedSlice,)),
         purpose_from_list(
             [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
         ),
     ),
-    (set(("Fill", "Pack", "Range")), all_parameter),
-    (
-        set(("Requantize",)),
-        purpose_from_list(
-            [
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-                TensorPurpose.FeatureMap,
-            ]
-        ),
-    ),
-    (set(("Placeholder", "SubgraphInput", "Const", "VariableV2")), purpose_from_list([])),
-    (set(("FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars")), input0_from_output_rest_parameter),
-    (
-        set(("Square", "Sqrt", "Log", "Less", "Enter", "Exit", "Identity", "StopGradient", "Merge", "Switch")),
-        inputs_from_output,
-    ),
+    (set((Op.Fill, Op.Pack, Op.Range)), all_parameter),
+    (set((Op.Placeholder, Op.SubgraphInput, Op.Const,)), purpose_from_list([])),
+    (set((Op.FakeQuantWithMinMaxArgs,)), input0_from_output_rest_parameter),
+    (set((Op.Square, Op.Sqrt, Op.Log, Op.Less, Op.Identity,)), inputs_from_output,),
     (None, all_fm),
 ]
 
@@ -247,8 +187,6 @@
 for ops, input_purpose in tensor_purposes:
     if ops is None:
         continue
-    for op in ops:
-        assert len(op) > 1, "string literal has been decomposed"
 
 
 def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
@@ -260,7 +198,7 @@
         tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]
         tens.mem_type = arch.tensor_storage_mem_type[tens.purpose]
 
-        if len(tens.ops) == 1 and tens.ops[0].type == "Const":
+        if len(tens.ops) == 1 and tens.ops[0].type == Op.Const:
             tens.mem_area = (
                 arch.permanent_storage_mem_area
             )  # special case constants, as they must be in permanent storage
@@ -288,11 +226,11 @@
                     purpose = input_purpose(op, idx) if tens.purpose == TensorPurpose.Unknown else tens.purpose
                     mark_tensor_helper(tens, purpose)
 
-                if op.type == "Reshape":
+                if op.type == Op.Reshape:
                     # Reshape's input and output point to the same data
                     op.outputs[0].mem_area = op.inputs[0].mem_area
 
-                if op.type.startswith(custom_prefix) and op.attrs.get("custom_type", "") == "ExistingNpuOp":
+                if op.type == Op.Custom and op.attrs.get("custom_type") == CustomType.ExistingNpuOp:
                     scratch_tensor = None
 
                     if len(op.inputs) >= 3:
@@ -301,7 +239,7 @@
                             scratch_tensor.purpose = TensorPurpose.Scratch
 
                     if scratch_tensor is None:
-                        raise OperatorError(op, "Scratch tensor not found.")
+                        OperatorError(op, "Scratch tensor not found.")
 
                 break
 
@@ -318,21 +256,6 @@
     return nng
 
 
-reshape_operations = set(
-    (
-        "Reshape",
-        "QuantizedReshape",
-        "ExpandDims",
-        "Squeeze",
-        "BatchToSpaceND",
-        "SpaceToBatchND",
-        "DepthToSpaceND",
-        "SpaceToDepthND",
-        "Placeholder",
-    )
-)
-
-
 def mark_tensor_format(nng, arch, verbose_tensor_format=False):
     formats_for_tensor = {}
 
@@ -375,8 +298,9 @@
             if src_tens is not None:
                 op = tens.find_npu_op()
                 if op is not None:
-                    npu_block_type = op.attrs["npu_block_type"]
-                    weight_compressor.compress_weights(arch, nng, tens, npu_block_type, 16, 16, op.get_dilation_h_w())
+                    weight_compressor.compress_weights(
+                        arch, nng, tens, op.type.npu_block_type, 16, 16, op.get_dilation_h_w()
+                    )
                     # Alias compressed weights back into source tensor
                     src_tens.copy_compressed_weight_info(tens)
 
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 58aab61..12edf5e 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -22,6 +22,8 @@
 # Graph - A full neural network graph with one or more Subgraphs.
 import enum
 
+from .operation import Op
+
 
 class PassPlacement(enum.Enum):
     Unknown = 0
@@ -176,7 +178,7 @@
                 visit_tensor(inp)
                 inp.consumer_list.append(op)
 
-            if op.type in set(("Placeholder", "SubgraphInput")):
+            if op.type in set((Op.Placeholder, Op.SubgraphInput)):
                 assert len(op.outputs) == 1
                 self.input_tensors.append(op.outputs[0])
 
@@ -319,19 +321,14 @@
         all_ops = self.get_all_ops()
         unique_ops = []
         for op in all_ops:
-            if op.type in set(("Const", "Identity", "Placeholder")):
+            if op.type in set((Op.Const, Op.Identity, Op.Placeholder)):
                 continue
 
-            attrs = op.attrs
-            if (
-                op.type == "Conv2D"
-                or op.type == "DepthwiseConv2dNative"
-                or op.type == "Conv2DBiasAct"
-                or op.type == "DepthwiseConv2dBiasAct"
-            ):
+            attrs = op.attrs.copy()
+            if op.type in (Op.Conv2D, Op.Conv2DBias, Op.DepthwiseConv2DBias):
                 kshape = op.inputs[1].shape
                 attrs["kshape"] = [kshape[0], kshape[1]]
-            attrs["type"] = op.type
+            attrs["type"] = op.type.name
             attrs.pop("use_cudnn_on_gpu", None)
             if attrs not in unique_ops:
                 unique_ops.append(attrs)
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index e09fc9e..fc148f3 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -41,10 +41,7 @@
 
     if ps2.npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
         op = ps2.primary_op
-        ifm_idx, _, _, _, _ = op.get_ifm_ifm2_weight_bias_ofm_indices()
-        ifm_block_depth = arch.calc_ifm_block_depth(
-            op.inputs[ifm_idx].shape[-1], op.inputs[ifm_idx].dtype.size_in_bits()
-        )
+        ifm_block_depth = arch.calc_ifm_block_depth(op.ifm.shape[-1], op.ifm.dtype.size_in_bits())
     else:
         ifm_block_depth = block_config_ps2[-1]
 
@@ -237,8 +234,8 @@
     elif primary_op:
         skirt = primary_op.attrs.get("skirt", skirt)
         explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
-        assert primary_op.attrs["npu_block_type"] == ps.npu_block_type
-        npu_block_type = primary_op.attrs["npu_block_type"]
+        assert primary_op.type.npu_block_type == ps.npu_block_type
+        npu_block_type = primary_op.type.npu_block_type
 
         ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
 
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index 430db58..7989fa9 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -22,6 +22,7 @@
 from . import driver_actions
 from .data_type import DataType
 from .nn_graph import PassPlacement
+from .operation import Op
 from .operation import Operation
 from .tensor import MemArea
 from .tensor import MemType
@@ -125,7 +126,7 @@
                     # For DMA ops, ps.weight_tensor is referring to the SRAM weight tensor and therefore the address
                     # is pointing at the destination address of where the weights should be placed in SRAM.
                     # This ensures that the Flash weight tensor is used instead and thus gets the correct address.
-                    if ps.weight_tensor.ops[0].type == "DMA":
+                    if ps.weight_tensor.ops[0].type == Op.DMA:
                         copy_compressed_values_to_memory_tensor(sg.flash_tensor, ps.weight_tensor.ops[0].inputs[0])
                     else:
                         copy_compressed_values_to_memory_tensor(sg.flash_tensor, ps.weight_tensor)
@@ -150,7 +151,7 @@
 
 
 def add_const_tens_to_startup_cascaded_pass(startup_cps, tens):
-    op = Operation("Const", tens.name + "_const")
+    op = Operation(Op.Const, tens.name + "_const")
     op.set_output_tensor(tens)
     startup_cps.passes[0].ops.insert(0, op)
     startup_cps.passes[0].outputs.insert(0, tens)
@@ -166,9 +167,8 @@
     for idx, cps in enumerate(sg.cascaded_passes):
         for ps in cps.passes:
             for op in ps.ops:
-                if op.type == "NpuOp":
+                if op.type == Op.CustomNpuOp:
                     callee = op.attrs["subgraph"]
-                    op.attrs["custom_type"] = op.type
 
                     sz = 0
                     for tens in [
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 1481887..a2b67df 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 # Description:
 # Internal representation of a Neural Network Operation.
-import enum
+from collections import namedtuple
+from enum import Enum
 
 
-class NpuBlockType(enum.Enum):
+class NpuBlockType(Enum):
     Default = 0
     ConvolutionMxN = 1
     VectorProduct = 2
@@ -28,10 +29,266 @@
     ReduceSum = 6
 
 
+# Classifies operators of type Custom
+class CustomType(Enum):
+    ThirdPartyOp = 0  # Third party custom op
+    NpuOp = 1  # NPU op
+    ExistingNpuOp = 2  # NPU op that was part of the input network
+
+
+TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])
+
+NO_INDICES = TensorIndices([], [], [])
+IFM_INDICES = TensorIndices([0], [], [])
+IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
+IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
+IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
+CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
+TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
+CONCAT_INDICES = TensorIndices([1, 2], [], [])
+SPLIT_IFM_INDICES = TensorIndices([1], [], [])
+BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
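+# Example: CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3]) encodes
+# that for Conv2DBackpropInput the IFM is input 2, the weights are input 1
+# and the bias is input 3.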
+
+
+# Static information related to operation codes
+class OperatorInfo:
+    __slots__ = ("id", "block_type", "indices", "is_unary")
+    _id = 0
+
+    def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
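+        # Each operator code gets a unique id, assigned in declaration order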
+        OperatorInfo._id += 1
+        self.id = OperatorInfo._id
+        self.block_type = block_type
+        self.indices = indices  # Indices of the different tensor purposes
+        self.is_unary = is_unary  # True for unary elementwise operators
+
+
+# Internally used operation codes
+class Op(Enum):
+    Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
+    Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    AddN = OperatorInfo()
+    Any = OperatorInfo()
+    ArgMax = OperatorInfo()
+    ArgMin = OperatorInfo()
+    AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    BatchMatMul = OperatorInfo()
+    BatchToSpaceND = OperatorInfo()
+    BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)
+
+    CLZ = OperatorInfo(
+        block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
+    )  # NPU specific operation
+    Call = OperatorInfo()
+    Cast = OperatorInfo()
+    Ceil = OperatorInfo()
+    Concat = OperatorInfo(indices=CONCAT_INDICES)
+    ConcatEmbeddings = OperatorInfo()
+    ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
+    ConcatTFLite = OperatorInfo()
+    Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
+    Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
+    Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
+    Conv2DBackpropInputSwitchedBias = OperatorInfo(
+        block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
+    )
+    Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
+    Cos = OperatorInfo()
+    Custom = OperatorInfo()  # Custom 3rd party operator, only used in CPU subgraphs
+    CustomNpuOp = OperatorInfo()  # NPU custom operator, only used in CPU subgraphs
+    DMA = OperatorInfo()
+    Delegate = OperatorInfo()
+    Densify = OperatorInfo()
+    DepthToSpace = OperatorInfo()
+    DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
+    Dequantize = OperatorInfo()
+    Div = OperatorInfo()
+    Elu = OperatorInfo()
+    EmbeddingLookup = OperatorInfo()
+    EmbeddingLookupSparse = OperatorInfo()
+    Equal = OperatorInfo()
+    Exp = OperatorInfo()
+    ExpandDims = OperatorInfo(indices=IFM_INDICES)
+    FakeQuantWithMinMaxArgs = OperatorInfo()
+    Fill = OperatorInfo()
+    Floor = OperatorInfo()
+    FloorDiv = OperatorInfo()
+    FloorMod = OperatorInfo()
+    FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
+    GatherNd = OperatorInfo()
+    GatherV2 = OperatorInfo()
+    Greater = OperatorInfo()
+    GreaterEqual = OperatorInfo()
+    HardSwish = OperatorInfo()
+    HashtableLookup = OperatorInfo()
+    Identity = OperatorInfo()
+    If = OperatorInfo()
+    L2Norm = OperatorInfo()
+    L2Pool2D = OperatorInfo()
+    LRN = OperatorInfo()
+    LSHProjection = OperatorInfo()
+    LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
+    Less = OperatorInfo()
+    LessEqual = OperatorInfo()
+    Log = OperatorInfo()
+    LogSoftmax = OperatorInfo()
+    LogicalAnd = OperatorInfo()
+    LogicalNot = OperatorInfo()
+    LogicalOr = OperatorInfo()
+    Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    LUT = OperatorInfo()  # NPU specific: operator with LUT; only used as a fused activation function
+    MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    MatrixDiag = OperatorInfo()
+    MatrixSetDiag = OperatorInfo()
+    Max = OperatorInfo()
+    MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    Mean = OperatorInfo()
+    Min = OperatorInfo()
+    Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    MirrorPad = OperatorInfo()
+    Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    Neg = OperatorInfo()
+    NonMaxSuppressionV4 = OperatorInfo()
+    NonMaxSuppressionV5 = OperatorInfo()
+    NotEqual = OperatorInfo()
+    OneHot = OperatorInfo()
+    Pack = OperatorInfo()
+    PackReshaped = OperatorInfo(indices=IFM_INDICES)
+    Pad = OperatorInfo()
+    PadV2 = OperatorInfo()
+    Placeholder = OperatorInfo()  # Only used in CPU subgraphs
+    Pow = OperatorInfo()
+    Prelu = OperatorInfo()
+    Prod = OperatorInfo()
+    Quantize = OperatorInfo()
+    QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
+    QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
+    Range = OperatorInfo()
+    Rank = OperatorInfo()
+    ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
+    Relu = OperatorInfo(indices=IFM_INDICES)
+    Relu6 = OperatorInfo(indices=IFM_INDICES)
+    ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
+    Reshape = OperatorInfo(indices=IFM_INDICES)
+    ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+    ResizeNearestNeighbor = OperatorInfo()
+    ReverseSequence = OperatorInfo()
+    ReverseV2 = OperatorInfo()
+    Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    Round = OperatorInfo()
+    Rsqrt = OperatorInfo()
+    SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
+    SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)  # NPU specific operation
+    ScatterNd = OperatorInfo()
+    SegmentSum = OperatorInfo()
+    Select = OperatorInfo()
+    SelectV2 = OperatorInfo()
+    Shape = OperatorInfo()
+    Sigmoid = OperatorInfo(indices=IFM_INDICES)
+    SignBit = OperatorInfo()
+    Sin = OperatorInfo()
+    SkipGram = OperatorInfo()
+    Slice = OperatorInfo(indices=IFM_INDICES)
+    Softmax = OperatorInfo()
+    SpaceToBatchND = OperatorInfo()
+    SpaceToDepth = OperatorInfo()
+    SparseToDense = OperatorInfo()
+    Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
+    SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
+    SplitV = OperatorInfo(indices=IFM_INDICES)
+    Sqrt = OperatorInfo()
+    Square = OperatorInfo()
+    SquaredDifference = OperatorInfo()
+    Squeeze = OperatorInfo(indices=IFM_INDICES)
+    StridedSlice = OperatorInfo(indices=IFM_INDICES)
+    StridedSliceOptions = OperatorInfo()
+    Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+    SubgraphInput = OperatorInfo()  # Only used in CPU subgraphs
+    Sum = OperatorInfo()
+    Svdf = OperatorInfo()
+    Tanh = OperatorInfo(indices=IFM_INDICES)
+    Tile = OperatorInfo()
+    TopKV2 = OperatorInfo()
+    Transpose = OperatorInfo()
+    UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+    Unique = OperatorInfo()
+    Unpack = OperatorInfo()
+    UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
+    Where = OperatorInfo()
+    While = OperatorInfo()
+    ZerosLike = OperatorInfo()
+
+    @property
+    def info(self):
+        return self.value
+
+    @property
+    def npu_block_type(self):
+        return self.info.block_type
+
+    def is_conv2d_op(self):
+        return self.info.block_type == NpuBlockType.ConvolutionMxN
+
+    def is_depthwise_conv2d_op(self):
+        return self.info.block_type == NpuBlockType.ConvolutionDepthWise
+
+    def is_pool_op(self):
+        return self.info.block_type == NpuBlockType.Pooling
+
+    def is_maxpool_op(self):
+        return self in (Op.MaxPool, Op.QuantizedMaxPool)
+
+    def is_avgpool_op(self):
+        return self in (Op.QuantizedAvgPool, Op.AvgPool)
+
+    def is_elementwise_op(self):
+        return self.info.block_type == NpuBlockType.ElementWise
+
+    def is_unary_elementwise_op(self):
+        return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary
+
+    def is_binary_elementwise_op(self):
+        return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary
+
+    def is_relu_op(self):
+        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1)
+
+    def is_activation_op(self):
+        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
+
+    def is_split_op(self):
+        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)
+
+    def is_concat_op(self):
+        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)
+
+    def needs_bias(self):
+        return bool(self.info.indices.biases)
+
+    @classmethod
+    def op_set(cls, predicate):
+        # Returns the set of all operator codes that fulfill the given predicate
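+        # e.g. Op.op_set(Op.is_relu_op) -> {Op.Relu, Op.Relu6, Op.ReluN1To1}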
+        return {op_type for op_type in Op if predicate(op_type)}
+
+    def __str__(self):
+        return self.name
+
+    __repr__ = __str__
+
+    def __lt__(self, other):
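+        # Enum members sort by declaration order, via the unique id assigned in OperatorInfo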
+        return self.value.id < other.value.id
+
+
 def create_avgpool_nop(name):
-    op = Operation("AvgPool", name)
+    op = Operation(Op.AvgPool, name)
     op.attrs["padding"] = b"VALID"
-    op.attrs["npu_block_type"] = NpuBlockType.Pooling
     op.attrs["stride_w"] = 1
     op.attrs["stride_h"] = 1
     op.attrs["filter_width"] = 1
@@ -70,6 +327,9 @@
         "flops",
         "scheduled_pass",
         "run_on_npu",
+        "activation",
+        "memory_function",
+        "forced_output_quantization",
         "activation_lut",
     )
 
@@ -81,6 +341,13 @@
         self.outputs = []
         self.flops = 0
         self.run_on_npu = True
+        # Fused activation function. If not None: operator code
+        self.activation = None
+        # Fused memory function. If not None: operator code
+        self.memory_function = None
+        # If not None: contains QuantizationParameters to be used as output quantization
+        # (which overrides the OFM tensor's quantization), used with fused LUT
+        self.forced_output_quantization = None
         self.scheduled_pass = None
         self.op_index = None  # input network operator index
         self.activation_lut = None
@@ -92,173 +359,95 @@
         res.inputs = list(self.inputs)
         res.outputs = list(self.outputs)
         res.flops = self.flops
+        res.run_on_npu = self.run_on_npu
+        res.activation = self.activation
+        res.memory_function = self.memory_function
+        res.forced_output_quantization = self.forced_output_quantization
         res.scheduled_pass = self.scheduled_pass
         res.op_index = None  # not relevant as not part of input network
 
         return res
 
     def __str__(self):
-        return "<nng.Operation '%s' type=%s>" % (self.name, self.type)
+        return "<nng.Operation '{}' type={}>".format(self.name, self.type)
 
     __repr__ = __str__
 
-    def get_ifm_ifm2_weight_bias_ofm_indices(self):
-        ifm_idx = -1
-        ifm2_idx = -1
-        weight_idx = -1
-        bias_idx = -1
-        ofm_idx = -1
-        npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
-        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise):
-            ifm_idx = 0
-            weight_idx = 1
-            ofm_idx = 0
-
-            if self.type in ("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct"):
-                if len(self.inputs) >= 3:
-                    bias_idx = 2
-
-            elif self.type == "Conv2DBackpropInputSwitchedBias":
-                bias_idx = 3
-
-        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
-            ifm_idx = 0
-            ofm_idx = 0
-        elif npu_block_type == NpuBlockType.VectorProduct:
-            ifm_idx = 0
-            weight_idx = 1
-            ofm_idx = 0
-
-            if self.type == "FullyConnectedAct":
-                if len(self.inputs) >= 3:
-                    bias_idx = 2
-
-            if self.type == "BlockLSTM":
-                ifm_idx = 3
-                weight_idx = 4
-                ofm_idx = 6
-
-        elif npu_block_type == NpuBlockType.ElementWise:
-            ifm_idx = 0
-            ifm2_idx = 1
-            ofm_idx = 0
-
-            # LeakyRelu, Abs and CLZ have a single IFM
-            if self.type in ("LeakyRelu", "Abs", "CLZ"):
-                ifm2_idx = -1
-
-        elif self.type == "Conv2DBackpropInput":
-            ifm_idx = 2
-            weight_idx = 1
-            ofm_idx = 0
-
-        elif self.type in ("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims", "Sigmoid", "Tanh"):
-            ifm_idx = 0
-            ofm_idx = 0
-
-        elif self.is_split_op():
-            ifm_idx = 0
-            ofm_idx = 0
-            if self.type == "Split":
-                ifm_idx = 1
-
-        elif self.is_concat_op():
-            ifms, _ = self.get_concat_inputs_axis()
-            ifm_idx = self.inputs.index(ifms[0])
-            if len(ifms) > 1:
-                ifm2_idx = self.inputs.index(ifms[1])
-            ofm_idx = 0
-
-        return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx
-
     def get_ifm_ifm2_weights_ofm(self):
-        ifm_tensor = None
-        ifm2_tensor = None
-        weight_tensor = None
-        ofm_tensor = None
-
-        ifm_idx, ifm2_idx, weight_idx, _, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
-        if ifm_idx != -1:
-            ifm_tensor = self.inputs[ifm_idx]
-        if ifm2_idx != -1:
-            ifm2_tensor = self.inputs[ifm2_idx]
-        if weight_idx != -1:
-            weight_tensor = self.inputs[weight_idx]
-        if ofm_idx != -1:
-            ofm_tensor = self.outputs[ofm_idx]
-
-        return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor
+        return self.ifm, self.ifm2, self.weights, self.ofm
 
     def get_ifm_weights_biases_ofm(self):
-        ifm_tensor = None
-        weight_tensor = None
-        bias_tensor = None
-        ofm_tensor = None
-
-        ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
-        if ifm_idx != -1:
-            ifm_tensor = self.inputs[ifm_idx]
-        if weight_idx != -1:
-            weight_tensor = self.inputs[weight_idx]
-        if bias_idx != -1:
-            bias_tensor = self.inputs[bias_idx]
-        if ofm_idx != -1:
-            ofm_tensor = self.outputs[ofm_idx]
-
-        return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor
+        return self.ifm, self.weights, self.bias, self.ofm
 
     def get_ifm_ifm2_weights_biases_ofm(self):
-        ifm_tensor = None
-        ifm2_tensor = None
-        weight_tensor = None
-        bias_tensor = None
-        ofm_tensor = None
+        return self.ifm, self.ifm2, self.weights, self.bias, self.ofm
 
-        ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
-        if ifm_idx != -1:
-            ifm_tensor = self.inputs[ifm_idx]
-        if ifm2_idx != -1:
-            ifm2_tensor = self.inputs[ifm2_idx]
-        if weight_idx != -1:
-            weight_tensor = self.inputs[weight_idx]
-        if bias_idx != -1:
-            bias_tensor = self.inputs[bias_idx]
-        if ofm_idx != -1:
-            ofm_tensor = self.outputs[ofm_idx]
+    def get_ifm_ofm(self):
+        return self.ifm, self.ofm
 
-        return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor
+    @property
+    def ifm(self):
+        # Gets the IFM tensor, or None if not applicable
+        return self.get_input(self.type.info.indices.ifms, 0)
 
-    def get_ofm(self):
-        _, _, _, ofm = self.get_ifm_ifm2_weights_ofm()
-        return ofm
+    @property
+    def ifm2(self):
+        # Gets the IFM2 tensor, or None if not applicable
+        return self.get_input(self.type.info.indices.ifms, 1)
 
-    def is_concat_op(self):
-        return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped")
+    @property
+    def bias(self):
+        # Gets the bias tensor, or None if not applicable
+        return self.get_input(self.type.info.indices.biases, 0)
+
+    @property
+    def weights(self):
+        # Gets the weight tensor, or None if not applicable
+        return self.get_input(self.type.info.indices.weights, 0)
+
+    def get_ifm_tensors(self):
+        # Gets the IFM tensors, or empty list if not applicable
+        return self._index_list_to_tensors(self.type.info.indices.ifms)
+
+    def get_weight_tensors(self):
+        # Gets the weight tensors, or empty list if not applicable
+        return self._index_list_to_tensors(self.type.info.indices.weights)
+
+    def get_bias_tensors(self):
+        # Gets the bias tensors, or empty list if not applicable
+        return self._index_list_to_tensors(self.type.info.indices.biases)
+
+    def _index_list_to_tensors(self, index_list):
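+        # Maps a list of input positions to tensors, skipping positions beyond len(self.inputs)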
+        return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]
+
+    def get_input(self, index_list, ix):
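+        # Returns self.inputs[index_list[ix]], or None if either index is out of range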
+        if ix >= len(index_list):
+            return None
+        if index_list[ix] >= len(self.inputs):
+            return None
+        return self.inputs[index_list[ix]]
+
+    @property
+    def ofm(self):
+        # Gets the OFM tensor, or None if not applicable
+        return self.outputs[0] if self.outputs else None
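+
+    # Example: for Op.Conv2DBias (IFM_WEIGHTS_BIAS_INDICES), op.ifm is inputs[0],
+    # op.weights inputs[1] and op.bias inputs[2]; op.ifm2 is None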
 
     def get_concat_inputs_axis(self):
-        assert self.is_concat_op()
+        assert self.type.is_concat_op()
 
-        if self.type == "ConcatV2":
-            axis_tensor = self.inputs[-1]
-            inputs = self.inputs[:-1]
-        elif self.type == "Concat":
+        if self.type == Op.Concat:
             axis_tensor = self.inputs[0]
             inputs = self.inputs[1:]
-        elif self.type == "QuantizedConcat":
-            axis_tensor = self.inputs[0]
-            inputs = self.inputs[1:]
-            inputs = inputs[: len(inputs) // 3]  # Skip min/max
-
-        if self.type == "ConcatTFLite":
+            # The axis arrives as a constant tensor; resolve it in this branch so
+            # that axis is always bound when we reach the return below
+            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
+            axis = int(axis_tensor.values)
+        elif self.type == Op.ConcatTFLite:
             inputs = self.inputs
             axis = self.attrs["axis"]
-        elif self.type == "PackReshaped":
+        elif self.type == Op.PackReshaped:
             # Requires fixup_pack_input to be called before this point
             inputs = self.inputs
             axis = self.attrs["axis"]
             assert len(self.inputs) == self.attrs["values_count"]
-        else:
-            assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
-            axis = int(axis_tensor.values)
 
         return inputs, axis
@@ -267,33 +456,30 @@
         _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
         return dilation_h, dilation_w
 
-    def is_split_op(self):
-        return self.type in ("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped")
-
     def get_split_inputs_axis(self):
-        assert self.is_split_op()
+        assert self.type.is_split_op()
 
         offset_start = None
         offset_end = None
         axis = None
-        if self.type == "Split":
+        if self.type == Op.Split:
             num_splits = self.attrs.get("num_splits")
             axis_tens = self.inputs[0]
-            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
+            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
             axis = int(axis_tens.values)
             input_tens = self.inputs[1]
             outputs = self.outputs
             assert num_splits == len(outputs)
 
-        elif self.type == "SplitV":
+        elif self.type == Op.SplitV:
             num_splits = self.attrs.get("num_splits")
             input_tens = self.inputs[0]
             size_tens = self.inputs[1]
-            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
+            assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
             sizes = size_tens.values
 
             axis_tens = self.inputs[2]
-            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
+            assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
             axis = int(axis_tens.values)
 
             for idx, size in enumerate(sizes):
@@ -306,7 +492,7 @@
             assert num_splits == len(outputs)
             assert sum(sizes) == input_tens.shape[axis]
 
-        elif self.type == "Slice":
+        elif self.type == Op.Slice:
             input_tens, begin_tens, size_tens = self.inputs
             outputs = self.outputs
             offset_start = [0] * len(input_tens.shape)
@@ -318,7 +504,7 @@
                     offset_start[idx] = begin_tens.values[idx]
                     offset_end[idx] = size_tens.values[idx] + offset_start[idx]
 
-        elif self.type == "StridedSlice":
+        elif self.type == Op.StridedSlice:
             input_tens, begin_tens, end_tens, strides_tens = self.inputs
             outputs = self.outputs
             out_tens = outputs[0]
@@ -336,7 +522,7 @@
             assert len(input_tens.shape) == len(out_tens.shape)
             offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
             offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
-        elif self.type == "UnpackReshaped":
+        elif self.type == Op.UnpackReshaped:
             # Requires fixup_unpack_output to be called before this point
             input_tens = self.inputs[0]
             outputs = self.outputs
@@ -350,7 +536,7 @@
         return input_tens, outputs, axis, offset_start, offset_end
 
     def set_activation_lut(self, lut_tensor):
-        self.attrs["fused_activation_function"] = "LUT"
+        self.activation = Op.LUT
         self.activation_lut = lut_tensor
         self.add_input_tensor(lut_tensor)
 
@@ -372,13 +558,7 @@
         tens.ops = [self]
         self.outputs = [tens]
 
-    def needs_bias(self):
-        return self.type in (
-            "Conv2DBiasAct",
-            "DepthwiseConv2dBiasAct",
-            "Conv2DBackpropInputSwitchedBias",
-            "FullyConnectedAct",
-        )
-
     def get_output_quantization(self):
-        return self.attrs.get("forced_output_quantization", self.get_ofm().quantization)
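+        # Returns the forced output quantization if set (used with fused LUT),
+        # otherwise the OFM tensor's quantization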
+        if self.forced_output_quantization is not None:
+            return self.forced_output_quantization
+        return self.ofm.quantization
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index f49f981..35e1b14 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -22,6 +22,7 @@
 from .nn_graph import PassPlacement
 from .operation import create_avgpool_nop
 from .operation import NpuBlockType
+from .operation import Op
 from .tensor import TensorPurpose
 
 
@@ -40,81 +41,57 @@
     PostFusingLimited = 8192
 
 
-npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
+npu_pre_ops = set((Op.SplitSliceRead,))
 
 mac_main_ops = set(
     (
         # convolutions
-        "Conv2DBiasAct",
-        "Conv2D",
-        "QuantizedConv2D",
-        "Conv2DBackpropInputSwitchedBias",
+        Op.Conv2DBias,
+        Op.Conv2D,
+        Op.QuantizedConv2D,
+        Op.Conv2DBackpropInputSwitchedBias,
         # depth-wise convolutions
-        "DepthwiseConv2dBiasAct",
-        "DepthwiseConv2dNative",
-        "QuantizedDepthwiseConv2D",
+        Op.DepthwiseConv2DBias,
         # FC layers
-        "QuantizedMatMul",
-        "MatMul",
-        "FullyConnectedAct",
+        Op.QuantizedMatMul,
+        Op.MatMul,
+        Op.FullyConnected,
         # RNN/LSTM/GRU
-        "BlockLSTM",
+        Op.BlockLSTM,
         # pooling
-        "QuantizedMaxPool",
-        "QuantizedAvgPool",
-        "AvgPool",
-        "MaxPool",
-        "AvgPoolAct",
-        "MaxPoolAct",
-        "ReduceSum",
+        Op.QuantizedMaxPool,
+        Op.QuantizedAvgPool,
+        Op.AvgPool,
+        Op.MaxPool,
+        Op.ReduceSum,
         # deconvolution
-        "ResizeBilinear",
+        Op.ResizeBilinear,
     )
 )
 
-binary_elem_wise_main_ops = set(
-    (
-        # binary element-wise
-        "AddAct",
-        "MulAct",
-        "SubAct",
-        "QuantizedAdd",
-        "QuantizedSub",
-        "QuantizedMul",
-        "Mul",
-        "Add",
-        "Sub",
-        "Minimum",
-        "Maximum",
-        "SHL",
-        "SHR",
-    )
-)
+binary_elem_wise_main_ops = Op.op_set(Op.is_binary_elementwise_op)
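+# Derived from the Op enum: Op.Add, Op.Sub, Op.Mul, Op.Minimum, Op.Maximum, Op.SHL, Op.SHR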
 
-unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))  # Unary element-wise operations
+unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)  # Unary element-wise operations
 
 elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
 
-activation_ops = set(("QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1"))
-npu_post_ops = activation_ops | set(
-    # Bias-add operations: Get rid of these once we have rewrites from Conv2D + BiasAdd + Activation to Conv2DBiasAct.
-    ("Mul", "Add", "QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm")
-)
+activation_ops = Op.op_set(Op.is_relu_op)
+npu_post_ops = activation_ops
 
 npu_post_fuse_limited_ops = set(
     # Set of post operators that should not be fused with main/elementwise ops
-    ("ConcatSliceWrite", "Sigmoid", "Tanh", "Quantize")
+    (Op.ConcatSliceWrite, Op.Sigmoid, Op.Tanh, Op.Quantize)
 )
 
-elem_wise_ops = elem_wise_main_ops | activation_ops | set(("Sigmoid", "Tanh"))
+elem_wise_ops = elem_wise_main_ops | activation_ops | set((Op.Sigmoid, Op.Tanh))
 
 
-quantization_ops = set(("Dequantize", "QuantizeV2", "Max", "Min"))
-cpu_ops = set(("Softmax", "QuantizedSoftmax", "LRN", "Shape", "QuantizedPad", "Pad", "AddN")) | quantization_ops
+quantization_ops = set((Op.Dequantize, Op.Max, Op.Min))
+cpu_ops = set((Op.Softmax, Op.LRN, Op.Shape, Op.Pad, Op.AddN)) | quantization_ops
 
-npu_dma_ops = set(("DMA",))
-startup_init_ops = set(("Const", "VariableV2", "Placeholder", "SubgraphInput"))
-memory_only_ops = set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",))
+npu_dma_ops = set((Op.DMA,))
+startup_init_ops = set((Op.Const, Op.Placeholder, Op.SubgraphInput))
+memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,))
 
 
 test_sequence = [
@@ -234,10 +211,6 @@
 for (operation_set, incompatible_pack_flags, flags_to_set, flags_to_clear) in test_sequence:
     assert not flags_to_clear & flags_to_set
 
-    if operation_set is not None:
-        for op in operation_set:
-            assert len(op) > 1  # This is to avoid string literals being decomposed
-
 
 def pack_into_passes(nng, arch, verbose_packing=False):
     def visit_op(op, ignored):
@@ -254,7 +227,7 @@
             if op.type in startup_init_ops:
                 startup_list.append(op)
             else:
-                _, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
+                ofm_tensor = op.ofm
                 if ofm_tensor is None:
                     ofm_tensor = op.outputs[0]
                 build_pass((op,), ofm_tensor)
@@ -287,7 +260,7 @@
                                 continue
 
                         reverse_ops_list.append(curr_op)
-                        new_block_type = curr_op.attrs.get("npu_block_type", NpuBlockType.Default)
+                        new_block_type = curr_op.type.npu_block_type
                         if new_block_type != NpuBlockType.Default:
                             assert npu_block_type == NpuBlockType.Default
                             npu_block_type = new_block_type  # Only one major block type per pass
@@ -302,10 +275,8 @@
                                 PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
                             ):
                                 assert len(curr_op.inputs) >= 1
-                                if curr_op.type == "BlockLSTM":
-                                    ifm_tensor = curr_op.inputs[3]
-                                else:
-                                    ifm_tensor = curr_op.inputs[0]
+                                ifm_tensor = curr_op.ifm
+                                assert ifm_tensor is not None
                                 assert ifm_tensor.purpose == TensorPurpose.FeatureMap
 
                         if flags_to_set & PassFlags.Dma:
@@ -377,7 +348,7 @@
             primary_op = create_primary_op(ops_list)
             if primary_op is not None:
                 visit_tensor_refcount[primary_op.inputs[0]] += 1
-                npu_block_type = primary_op.attrs["npu_block_type"]
+                npu_block_type = primary_op.type.npu_block_type
                 for input_tens in primary_op.inputs:
                     if input_tens not in input_set:
                         input_set.add(input_tens)
@@ -394,7 +365,7 @@
             for inp in primary_op.inputs:
                 if inp is None:
                     continue
-                if len(inp.ops) == 1 and inp.ops[0].type == "DMA" and inp.purpose == TensorPurpose.FeatureMap:
+                if len(inp.ops) == 1 and inp.ops[0].type == Op.DMA and inp.purpose == TensorPurpose.FeatureMap:
                     src_op = inp.ops[0]
                     if src_op in input_ops_list:
                         inp = src_op.inputs[0]
@@ -408,7 +379,7 @@
                 add_input_list(inp, input_set, input_refcounts, lut_list, ordered_input_list)
 
         name = ops_list[0].name
-        non_dma_ops = [op for op in ops_list if op.type != "DMA"]
+        non_dma_ops = [op for op in ops_list if op.type != Op.DMA]
         if non_dma_ops:
             name = non_dma_ops[0].name
         ps = Pass(name, placement, is_element_wise, npu_block_type)
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index da9be66..073b50f 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -50,6 +50,7 @@
 from .numeric_util import round_away_zero
 from .numeric_util import round_up_to_int
 from .operation import NpuBlockType
+from .operation import Op
 from .tensor import MemType
 from .tensor import TensorBlockTraversal
 from .tensor import TensorFormat
@@ -357,16 +358,16 @@
 
     # Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
     elementwise_mode_map = {
-        "MulAct": elementwise_mode.MUL.value,
-        "AddAct": elementwise_mode.ADD.value,
-        "SubAct": elementwise_mode.SUB.value,
-        "Minimum": elementwise_mode.MIN.value,
-        "Maximum": elementwise_mode.MAX.value,
-        "LeakyRelu": elementwise_mode.LRELU.value,
-        "Abs": elementwise_mode.ABS.value,
-        "CLZ": elementwise_mode.CLZ.value,
-        "SHR": elementwise_mode.SHR.value,
-        "SHL": elementwise_mode.SHL.value,
+        Op.Mul: elementwise_mode.MUL.value,
+        Op.Add: elementwise_mode.ADD.value,
+        Op.Sub: elementwise_mode.SUB.value,
+        Op.Minimum: elementwise_mode.MIN.value,
+        Op.Maximum: elementwise_mode.MAX.value,
+        Op.LeakyRelu: elementwise_mode.LRELU.value,
+        Op.Abs: elementwise_mode.ABS.value,
+        Op.CLZ: elementwise_mode.CLZ.value,
+        Op.SHR: elementwise_mode.SHR.value,
+        Op.SHL: elementwise_mode.SHL.value,
     }
 
     cmd_stream = []
@@ -439,15 +440,15 @@
             rounding_mode = (
                 rounding.NATURAL if primary_op.attrs.get("rounding_mode", "") == b"NATURAL" else rounding.TFL
             )
-            if primary_op.type == "ResizeBilinear":
+            if primary_op.type == Op.ResizeBilinear:
                 rounding_mode = rounding.TRUNCATE
-            fmf = primary_op.attrs.get("fused_memory_function", None)
-            faf = primary_op.attrs.get("fused_activation_function", None)
-            fused_quantize = any(op.type == "Quantize" for op in ps.ops)
+            fmf = primary_op.memory_function
+            faf = primary_op.activation
+            fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
             # Force output scale, used in operations with fused LUT
             # Note: with current LUT support, forced_ofm_quantization is always equal to cmd.ofm_tensor.quantization
             # except when primary_op is AddAct + 0 (no-op) + LUT
-            forced_ofm_quantization = primary_op.attrs.get("forced_output_quantization", None)
+            forced_ofm_quantization = primary_op.forced_output_quantization
             ofm_quant = cmd.ofm_tensor.quantization
             if forced_ofm_quantization is not None:
                 ofm_quant = forced_ofm_quantization
@@ -482,16 +483,16 @@
                     ifm2_broadcast |= IFM2Broadcast.ReverseOperandOrder
 
                 # Calculate scales needed for arithmetic elementwise operators
-                if primary_op.type in set(("AddAct", "MulAct", "SubAct",)):
+                if primary_op.type in set((Op.Add, Op.Mul, Op.Sub,)):
                     input_scale = cmd.ifm_tensor.quantization.scale_f32 if cmd.ifm_tensor.quantization else None
                     input2_scale = cmd.ifm2_tensor.quantization.scale_f32 if cmd.ifm2_tensor.quantization else None
                     output_scale = ofm_quant.scale_f32 if ofm_quant else None
                     use_global_scale = True
 
-                    if output_scale is not None and faf in ("Sigmoid", "Tanh"):
+                    if output_scale is not None and faf in (Op.Sigmoid, Op.Tanh):
                         output_scale = 1 / 0x3000
 
-                    if primary_op.type == "MulAct":
+                    if primary_op.type == Op.Mul:
                         if None in (input_scale, input2_scale, output_scale):
                             ofm_scale = 1
                             shift = 0
@@ -537,11 +538,11 @@
                         emit.cmd1_with_offset(cmd1.NPU_SET_OPB_SCALE, opb_scale)
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
 
-                elif primary_op.type in set(("LeakyRelu", "Abs",)):
+                elif primary_op.type in set((Op.LeakyRelu, Op.Abs,)):
                     output_scale = ofm_quant.scale_f32
                     use_global_scale = True
 
-                    if primary_op.type == "LeakyRelu":
+                    if primary_op.type == Op.LeakyRelu:
                         output_scale = primary_op.attrs["alpha"]
 
                     ofm_scale, shift = scaling.quantise_scale(output_scale)
@@ -599,10 +600,10 @@
 
             emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element])
 
-            if primary_op.type == "ResizeBilinear":
+            if primary_op.type == Op.ResizeBilinear:
                 # perform nearest neighbor upscale
                 emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.NEAREST)
-            elif primary_op.type == "Conv2DBackpropInputSwitchedBias":
+            elif primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
                 # perform insert zero upscale
                 emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.TRANSPOSE)
             else:
@@ -651,12 +652,9 @@
 
                     valid_padding = sum(explicit_padding) == 0
 
-                    if (
-                        primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "ReduceSum"))
-                        and valid_padding
-                    ):
+                    if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear, Op.ReduceSum)) and valid_padding:
                         # For valid padding vela has to output scaling values
-                        if faf == "Sigmoid" or faf == "Tanh":
+                        if faf == Op.Sigmoid or faf == Op.Tanh:
                             rescale = 0x3000 * cmd.ifm_tensor.quantization.scale_f32
                             if cmd.ifm_tensor.dtype == DataType.int16:
                                 # Calculate scale and shift for the output scale of 1/(3*4096)
@@ -675,7 +673,7 @@
                             ifm_scale_f64 = np.double(cmd.ifm_tensor.quantization.scale_f32)
                             ofm_scale_f64 = np.double(ofm_quant.scale_f32)
                             scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64)
-                        elif primary_op.type == "ResizeBilinear" and "rescale" in primary_op.attrs:
+                        elif primary_op.type == Op.ResizeBilinear and "rescale" in primary_op.attrs:
                             rescale = primary_op.attrs["rescale"]
                             rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
                             scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
@@ -689,7 +687,7 @@
                                 rescale = cmd.ifm_tensor.quantization.scale_f32 / ofm_quant.scale_f32
                                 rescale_bits = 0
                                 if k_height == k_width == 1:
-                                    if fmf == "ConcatSliceWrite":
+                                    if fmf == Op.ConcatSliceWrite:
                                         rounding_mode = rounding.NATURAL
                                     if rescale > 1:
                                         rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
@@ -814,35 +812,35 @@
                 # Even if no activation function, values need to be set to override previous values
                 faf_min = ofm_quant_qmin
                 faf_max = ofm_quant_qmax
-            elif faf == "Relu":
+            elif faf == Op.Relu:
                 emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                 faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                 faf_max = ofm_quant_qmax
-            elif faf == "Relu6":
+            elif faf == Op.Relu6:
                 emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                 faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                 faf_max = quantise_float32(6.0, ofm_quant.scale_f32, ofm_quant.zero_point)
-            elif faf == "ReluN1To1":
+            elif faf == Op.ReluN1To1:
                 emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                 faf_min = quantise_float32(-1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                 faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
-            elif faf == "Tanh":
+            elif faf == Op.Tanh:
                 emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.TANH)
-                if primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear")):
+                if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear)):
                     faf_min = quantise_float32(-1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                     faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                 else:
                     faf_min = quantise_float32(clamp_tanh(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                     faf_max = quantise_float32(clamp_tanh(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
-            elif faf == "Sigmoid":
+            elif faf == Op.Sigmoid:
                 emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.SIGMOID)
-                if primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear")):
+                if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear)):
                     faf_min = quantise_float32(0, ofm_quant.scale_f32, ofm_quant.zero_point)
                     faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                 else:
                     faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                     faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
-            elif faf == "LUT":
+            elif faf == Op.LUT:
                 lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1)
                 assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range."
                 if cmd.ofm_tensor.dtype == DataType.int32:
@@ -851,7 +849,7 @@
                 faf_min = ofm_quant_qmin
                 faf_max = ofm_quant_qmax
             else:
-                raise Exception("Unsupported fused_activation_function = " + faf)
+                raise Exception("Unsupported fused_activation_function = " + faf.name)
 
             # Activation range needs to be set based upon the quantisation range and the fused activation range
             emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MIN, max(ofm_quant_qmin, faf_min))
@@ -911,14 +909,11 @@
 
                 need_zero_point = (
                     (faf is not None and forced_ofm_quantization is None)
-                    or (fmf == "ConcatSliceWrite")
+                    or (fmf == Op.ConcatSliceWrite)
                     or fused_quantize
                 )
                 if (
-                    (
-                        primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "CLZ", "SHL"))
-                        and not need_zero_point
-                    )
+                    (primary_op.type in set((Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL)) and not need_zero_point)
                     or (
                         tens.dtype == DataType.int32
                         and zero_point_op in (cmd0.NPU_SET_IFM_ZERO_POINT, cmd0.NPU_SET_IFM2_ZERO_POINT)
@@ -933,7 +928,7 @@
                         zero_point = forced_ofm_quantization.zero_point
                     elif (
                         "resizebilinear" in primary_op.attrs
-                        and primary_op.type == "AddAct"
+                        and primary_op.type == Op.Add
                         and cmd0.NPU_SET_OFM_ZERO_POINT == zero_point_op
                     ):
                         # Force output zero point same as the input zero point
@@ -1108,7 +1103,7 @@
                 # Vector product is implemented using a 1x1 convolution
                 emit.cmd_do_operation(cmd0.NPU_OP_CONV)
             elif npu_block_type == NpuBlockType.Pooling:
-                param = pooling_mode.MAX.value if "Max" in primary_op.type else pooling_mode.AVERAGE.value
+                param = pooling_mode.MAX.value if primary_op.type.is_maxpool_op() else pooling_mode.AVERAGE.value
                 emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=param)
             elif npu_block_type == NpuBlockType.ReduceSum:
                 emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_mode.REDUCE_SUM.value)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 5c2ddab..41e1529 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -37,6 +37,7 @@
 from .npu_performance import PassCycles
 from .numeric_util import full_shape
 from .operation import NpuBlockType
+from .operation import Op
 from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
 from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
 from .tensor import MemArea
@@ -254,11 +255,7 @@
         self.pareto_max_candidates = 16
 
         self.ifm_stream_npu_blocks = set(
-            (
-                NpuBlockType.ConvolutionMxN,
-                NpuBlockType.ConvolutionDepthWise,
-                NpuBlockType.Pooling,
-            )
+            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
         )
 
     num_pareto_metrics = 4
@@ -652,7 +649,7 @@
     def avoid_for_cascading(self, pred_candidate):
         for op in pred_candidate.ops:
             if (
-                op.type == "ConcatSliceWrite"
+                op.type == Op.ConcatSliceWrite
                 and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
             ):
                 # For SRAM spilling, concat op is avoided as predecessor
@@ -981,9 +978,9 @@
                             use_NHCWB16 = False
                             use_fast_storage = False
                             continue
-                        if op.type == "ReduceSum" and output.dtype == DataType.int32:
+                        if op.type == Op.ReduceSum and output.dtype == DataType.int32:
                             use_NHCWB16 = False
-                        elif op.type == "Reshape":
+                        elif op.type == Op.Reshape:
                             # Detect no-op reshapes by comparing their full input and output tensor shapes.
                             inshape = full_shape(4, op.inputs[0].shape, 1)
                             outshape = full_shape(4, op.outputs[0].shape, 1)
@@ -995,7 +992,7 @@
                             incompatible_consumers = [
                                 (
                                     not consumer.run_on_npu
-                                    or consumer.type == "Reshape"
+                                    or consumer.type == Op.Reshape
                                     or (consumer is last_op_in_subgraph)
                                 )
                                 for consumer in op.outputs[0].consumer_list
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 58856a3..aa5f4c8 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -25,6 +25,7 @@
 from .errors import VelaError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .operation import NpuBlockType
+from .operation import Op
 from .range_set import MemoryRangeSet
 from .tensor import MemArea
 
@@ -39,7 +40,7 @@
         ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
         tensors = [t for t in (ifm_tensor, ifm2_tensor, ofm_tensor) if t is not None]
         scales = [t.quantization.scale_f32 for t in tensors if t.quantization is not None]
-        has_scale = len(tensors) == len(scales) and not None in scales
+        has_scale = len(tensors) == len(scales) and None not in scales
 
         strides = (1, 1, 1, 1)
         dilation = (1, 1, 1, 1)
@@ -53,7 +54,7 @@
             k_h = 1
             k_w = 1
             if weight_tensor:
-                if ps.primary_op.type != "FullyConnectedAct":
+                if ps.primary_op.type != Op.FullyConnected:
                     k_h = weight_tensor.shape[0]
                     k_w = weight_tensor.shape[1]
             else:
@@ -94,7 +95,9 @@
                     self.use_ifm_element == SHRAMElements.IFM16_Elementwise
                 )
             elif self.ifm_bits == 32:
-                assert self.is_elementwise or ps.npu_block_type == NpuBlockType.ReduceSum, "Unsupported 32-bit IFM operation"
+                assert (
+                    self.is_elementwise or ps.npu_block_type == NpuBlockType.ReduceSum
+                ), "Unsupported 32-bit IFM operation"
                 self.use_ifm_element = SHRAMElements.IFM32
             else:
                 assert self.ifm_bits == 8, "Unexpected IFM bitdepth"
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 5a5396f..12c2016 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -25,6 +25,7 @@
 from . import fp_math
 from . import scaling
 from .data_type import DataType
+from .operation import Op
 from .operation import Operation
 from .tensor import create_const_tensor
 from .tensor import create_reshape_tensor
@@ -229,7 +230,7 @@
 
         # PASS 0 - Depthwise Maxpool
         maxpool_op = self.op.clone("_maxpool0")
-        maxpool_op.type = "MaxPool"
+        maxpool_op.type = Op.MaxPool
         maxpool_h = ifm.shape[1] * ifm.shape[2]
         maxpool_w = ifm.shape[3]
         maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
@@ -246,7 +247,7 @@
         maxpool_op.set_output_tensor(ifm_max)
 
         # PASS 1 - Sub+LUT(exp)
-        sub_op = Operation("SubAct", self.op.name + "_sub1")
+        sub_op = Operation(Op.Sub, self.op.name + "_sub1")
         sub_op.add_input_tensor(ifm)
         sub_op.add_input_tensor(create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1]))
         sub_op.set_activation_lut(
@@ -262,7 +263,7 @@
         sub_op.set_output_tensor(ifm_exp)
 
         # PASS 2 - SHR
-        shr2_op = Operation("SHR", self.op.name + "_shr2")
+        shr2_op = Operation(Op.SHR, self.op.name + "_shr2")
         shr2_op.attrs["rounding_mode"] = b"NATURAL"
         shr2_op.add_input_tensor(ifm_exp)
         shr2_op.add_input_tensor(
@@ -275,7 +276,7 @@
         shr2_op.set_output_tensor(rescaled_exp)
 
         # PASS 3 - Reduce sum
-        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum3")
+        reduce_sum_op = Operation(Op.ReduceSum, self.op.name + "_reduce_sum3")
         reduce_sum_op.attrs["padding"] = b"VALID"
         reduce_sum_op.attrs["stride_w"] = 1
         reduce_sum_op.attrs["stride_h"] = 1
@@ -291,14 +292,14 @@
         reduce_sum_op.set_output_tensor(sum_of_exp)
 
         # PASS 4 - CLZ
-        clz_op = Operation("CLZ", self.op.name + "_clz4")
+        clz_op = Operation(Op.CLZ, self.op.name + "_clz4")
         clz_op.add_input_tensor(sum_of_exp)
         headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
         headroom_plus_one.quantization = no_scale_quant
         clz_op.set_output_tensor(headroom_plus_one)
 
         # PASS 5 - Sub
-        sub5_op = Operation("SubAct", self.op.name + "_sub5")
+        sub5_op = Operation(Op.Sub, self.op.name + "_sub5")
         sub5_op.add_input_tensor(
             create_const_tensor(
                 "headroom_offset_const",
@@ -316,7 +317,7 @@
 
         # PASS 6 - Sub
         one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
-        sub6_op = Operation("SubAct", self.op.name + "_sub6")
+        sub6_op = Operation(Op.Sub, self.op.name + "_sub6")
         sub6_op.add_input_tensor(headroom_plus_one)
         sub6_op.add_input_tensor(one)
         headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0")
@@ -324,7 +325,7 @@
         sub6_op.set_output_tensor(headroom)
 
         # PASS 7 - SHL
-        shl7_op = Operation("SHL", self.op.name + "_shl7")
+        shl7_op = Operation(Op.SHL, self.op.name + "_shl7")
         shl7_op.add_input_tensor(sum_of_exp)
         shl7_op.add_input_tensor(headroom)
         shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0")
@@ -332,7 +333,7 @@
         shl7_op.set_output_tensor(shifted_sum)
 
         # PASS 8 - Sub
-        sub8_op = Operation("SubAct", self.op.name + "_sub8")
+        sub8_op = Operation(Op.Sub, self.op.name + "_sub8")
         sub8_op.add_input_tensor(shifted_sum)
         sub8_op.add_input_tensor(
             create_const_tensor(
@@ -344,7 +345,7 @@
         sub8_op.set_output_tensor(shifted_sum_minus_one)
 
         # PASS 9 - SHL
-        shl9_op = Operation("SHL", self.op.name + "_shl9")
+        shl9_op = Operation(Op.SHL, self.op.name + "_shl9")
         shl9_op.add_input_tensor(shifted_sum_minus_one)
         shl9_op.add_input_tensor(one)
         shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
@@ -352,7 +353,7 @@
         shl9_op.set_output_tensor(shifted_sum_minus_one)
 
         # PASS 10 - Add
-        add10_op = Operation("AddAct", self.op.name + "_add10")
+        add10_op = Operation(Op.Add, self.op.name + "_add10")
         add10_op.add_input_tensor(
             create_const_tensor(
                 "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
@@ -365,7 +366,7 @@
         add10_op.set_output_tensor(half_denominator)
 
         # PASS 11 - Multiply
-        mul11_op = Operation("MulAct", self.op.name + "_mul11")
+        mul11_op = Operation(Op.Mul, self.op.name + "_mul11")
         mul11_op.add_input_tensor(half_denominator)
         mul11_op.add_input_tensor(
             create_const_tensor(
@@ -383,7 +384,7 @@
         mul11_op.set_output_tensor(rescaled)
 
         # PASS 12 - Add
-        add12_op = Operation("AddAct", self.op.name + "_add12")
+        add12_op = Operation(Op.Add, self.op.name + "_add12")
         add12_op.add_input_tensor(rescaled)
         add12_op.add_input_tensor(
             create_const_tensor(
@@ -403,7 +404,7 @@
         )
         for i in range(3):
             # PASS 13, 18, 23 - MUL
-            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
+            mul_op = Operation(Op.Mul, self.op.name + "_mul%d" % (13 + i * 5))
             mul_op.add_input_tensor(nr_x)
             mul_op.add_input_tensor(half_denominator)
             half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
@@ -411,14 +412,14 @@
             half_denominator_times_x.quantization.scale_f32 = 2.0
             mul_op.set_output_tensor(half_denominator_times_x)
             # PASS 14, 19, 24 - SUB
-            sub_op = Operation("SubAct", self.op.name + "_sub%d" % (14 + i * 5))
+            sub_op = Operation(Op.Sub, self.op.name + "_sub%d" % (14 + i * 5))
             sub_op.add_input_tensor(F2_one)
             sub_op.add_input_tensor(half_denominator_times_x)
             one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0")
             one_minus_half_denominator_times_x.quantization = one_scale_quant
             sub_op.set_output_tensor(one_minus_half_denominator_times_x)
             # PASS 15, 20, 25 - MUL
-            mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
+            mul_op = Operation(Op.Mul, self.op.name + "_mul%d" % (15 + i * 5))
             mul_op.add_input_tensor(nr_x)
             mul_op.add_input_tensor(one_minus_half_denominator_times_x)
             to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
@@ -426,14 +427,14 @@
             to_rescale.quantization.scale_f32 = 2.0
             mul_op.set_output_tensor(to_rescale)
             # PASS 16, 21, 26 - MUL
-            shl_op = Operation("MulAct", self.op.name + "_mul%d" % (16 + i * 5))
+            shl_op = Operation(Op.Mul, self.op.name + "_mul%d" % (16 + i * 5))
             shl_op.add_input_tensor(to_rescale)
             shl_op.add_input_tensor(four)
             to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
             to_add.quantization = no_scale_quant
             shl_op.set_output_tensor(to_add)
             # PASS 17, 22, 27 - ADD
-            add_op = Operation("AddAct", self.op.name + "_add%d" % (17 + i * 5))
+            add_op = Operation(Op.Add, self.op.name + "_add%d" % (17 + i * 5))
             add_op.add_input_tensor(nr_x)
             add_op.add_input_tensor(to_add)
             nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0")
@@ -441,7 +442,7 @@
             add_op.set_output_tensor(nr_x)
 
         # PASS 28 - Multiply
-        mul28_op = Operation("MulAct", self.op.name + "_mul28")
+        mul28_op = Operation(Op.Mul, self.op.name + "_mul28")
         mul28_op.add_input_tensor(nr_x)
         mul28_op.add_input_tensor(
             create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
@@ -451,7 +452,7 @@
         mul28_op.set_output_tensor(scale_factor)
 
         # PASS 29 - Multiply
-        mul_op = Operation("MulAct", self.op.name + "_mul29")
+        mul_op = Operation(Op.Mul, self.op.name + "_mul29")
         mul_op.add_input_tensor(ifm_exp)
         mul_op.add_input_tensor(scale_factor)
         scaled_exp = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
@@ -460,7 +461,7 @@
         mul_op.set_output_tensor(scaled_exp)
 
         # PASS 30 - SHR
-        shr30_op = Operation("SHR", self.op.name + "_shr30")
+        shr30_op = Operation(Op.SHR, self.op.name + "_shr30")
         shr30_op.attrs["rounding_mode"] = b"NATURAL"
         shr30_op.add_input_tensor(scaled_exp)
         shr30_op.add_input_tensor(right_shift)
@@ -474,7 +475,7 @@
 
         # PASS 0 - Depthwise Maxpool
         maxpool_op = self.op.clone("_maxpool0")
-        maxpool_op.type = "MaxPool"
+        maxpool_op.type = Op.MaxPool
         maxpool_h = ifm.shape[1] * ifm.shape[2]
         maxpool_w = ifm.shape[3]
         maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1]
@@ -491,7 +492,7 @@
         maxpool_op.set_output_tensor(maxpool_ofm)
 
         # PASS 1 - Sub
-        sub1_op = Operation("SubAct", self.op.name + "_sub1")
+        sub1_op = Operation(Op.Sub, self.op.name + "_sub1")
         sub1_op.add_input_tensor(ifm)
         sub1_op.add_input_tensor(create_reshape_tensor(maxpool_ofm, [1, ifm.shape[1], ifm.shape[2], 1]))
         sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0")
@@ -504,7 +505,7 @@
         mul2_scale, _ = scaling.elementwise_mul_scale(sub1_ofm.quantization.scale_f32, beta, mul2_out_range)
         mul2_quant = ifm.quantization.clone()
         mul2_quant.scale_f32 = beta
-        mul2_op = Operation("MulAct", self.op.name + "_mul2")
+        mul2_op = Operation(Op.Mul, self.op.name + "_mul2")
         mul2_op.add_input_tensor(sub1_ofm)
         mul2_op.add_input_tensor(
             create_const_tensor(
@@ -517,7 +518,7 @@
         mul2_op.set_output_tensor(mul2_ofm)
 
         # PASS 3 - Add+LUT(exp)
-        add_op = Operation("AddAct", self.op.name + "_add3")
+        add_op = Operation(Op.Add, self.op.name + "_add3")
         add_op.add_input_tensor(mul2_ofm)
         add_op.add_input_tensor(
             create_const_tensor(
@@ -534,7 +535,7 @@
         add_op.set_output_tensor(exp_ofm)
 
         # PASS 4 - Reduce sum
-        reduce_sum_op = Operation("ReduceSum", self.op.name + "_reduce_sum4")
+        reduce_sum_op = Operation(Op.ReduceSum, self.op.name + "_reduce_sum4")
         reduce_sum_op.attrs["padding"] = b"VALID"
         reduce_sum_op.attrs["stride_w"] = 1
         reduce_sum_op.attrs["stride_h"] = 1
@@ -550,14 +551,14 @@
         reduce_sum_op.set_output_tensor(sum_of_exp)
 
         # PASS 5 - CLZ
-        clz_op = Operation("CLZ", self.op.name + "_clz5")
+        clz_op = Operation(Op.CLZ, self.op.name + "_clz5")
         clz_op.add_input_tensor(sum_of_exp)
         headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0")
         headroom_plus_one.quantization = no_scale_quant
         clz_op.set_output_tensor(headroom_plus_one)
 
         # PASS 6 - Sub
-        sub6_op = Operation("SubAct", self.op.name + "_sub6")
+        sub6_op = Operation(Op.Sub, self.op.name + "_sub6")
         sub6_op.add_input_tensor(
             create_const_tensor(
                 sub6_op.name + "_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
@@ -569,7 +570,7 @@
         sub6_op.set_output_tensor(reciprocal_right_shift)
 
         # PASS 7 - SHL
-        shl7_op = Operation("SHL", self.op.name + "_shl7")
+        shl7_op = Operation(Op.SHL, self.op.name + "_shl7")
         shl7_op.add_input_tensor(
             create_const_tensor(
                 shl7_op.name + "_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant
@@ -581,7 +582,7 @@
         shl7_op.set_output_tensor(constant_one)
 
         # PASS 8 - Sub
-        sub8_op = Operation("SubAct", self.op.name + "_sub8")
+        sub8_op = Operation(Op.Sub, self.op.name + "_sub8")
         sub8_op.add_input_tensor(sum_of_exp)
         sub8_op.add_input_tensor(constant_one)
         sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0")
@@ -589,7 +590,7 @@
         sub8_op.set_output_tensor(sum_of_exps_minus_one)
 
         # PASS 9 - SHL
-        shl9_op = Operation("SHL", self.op.name + "_shl9")
+        shl9_op = Operation(Op.SHL, self.op.name + "_shl9")
         shl9_op.add_input_tensor(sum_of_exps_minus_one)
         shl9_op.add_input_tensor(headroom_plus_one)
         shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0")
@@ -597,7 +598,7 @@
         shl9_op.set_output_tensor(shifted_sum_minus_one)
 
         # PASS 10 - SHR
-        shr10_op = Operation("SHR", self.op.name + "_shr10")
+        shr10_op = Operation(Op.SHR, self.op.name + "_shr10")
         shr10_op.add_input_tensor(shifted_sum_minus_one)
         shr10_op.add_input_tensor(
             create_const_tensor(
@@ -609,7 +610,7 @@
         shr10_op.set_output_tensor(shifted_sum_minus_one_16)
 
         # PASS 11 - Sub+LUT(one over one plus x)
-        sub11_op = Operation("SubAct", self.op.name + "_sub11")
+        sub11_op = Operation(Op.Sub, self.op.name + "_sub11")
         sub11_op.add_input_tensor(shifted_sum_minus_one_16)
         sub11_op.add_input_tensor(
             create_const_tensor(
@@ -631,7 +632,7 @@
         sub11_op.set_output_tensor(reciprocal_scale)
 
         # PASS 12 - Multiply
-        mul_op = Operation("MulAct", self.op.name + "_mul12")
+        mul_op = Operation(Op.Mul, self.op.name + "_mul12")
         mul_op.add_input_tensor(exp_ofm)
         mul_op.add_input_tensor(reciprocal_scale)
         mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0")
@@ -639,7 +640,7 @@
         mul_op.set_output_tensor(mul_ofm)
 
         # PASS 13 - SHR
-        shr13_op = Operation("SHR", self.op.name + "_shr13")
+        shr13_op = Operation(Op.SHR, self.op.name + "_shr13")
         shr13_op.add_input_tensor(mul_ofm)
         shr13_op.add_input_tensor(reciprocal_right_shift)
         shr13_op.set_output_tensor(ofm)
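
Each of the softmax passes above uses the same three-step graph-building idiom, now keyed on Op enum members instead of strings. A minimal sketch of that idiom, using only helpers that appear in these hunks (the shapes, names and omitted quantization setup are illustrative):

    from ethosu.vela.data_type import DataType
    from ethosu.vela.operation import Op
    from ethosu.vela.operation import Operation
    from ethosu.vela.tensor import Tensor

    def build_mul(name, ifm, ifm2):
        # 1) create the operation from an Op member rather than a string
        mul_op = Operation(Op.Mul, name)
        # 2) wire up its inputs
        mul_op.add_input_tensor(ifm)
        mul_op.add_input_tensor(ifm2)
        # 3) create the output tensor and attach it (quantization setup omitted)
        ofm = Tensor(ifm.shape, DataType.int32, mul_op.name + "_0")
        mul_op.set_output_tensor(ofm)
        return ofm
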
diff --git a/ethosu/vela/stats_writer.py b/ethosu/vela/stats_writer.py
index 2ea14f2..6fd68f8 100644
--- a/ethosu/vela/stats_writer.py
+++ b/ethosu/vela/stats_writer.py
@@ -25,6 +25,7 @@
 from .npu_performance import MacCount
 from .npu_performance import PassCycles
 from .numeric_util import round_up_to_int
+from .operation import Op
 from .tensor import MemArea
 from .tensor import TensorPurpose
 
@@ -192,11 +193,11 @@
                     continue  # skip the dummy init pass
 
                 for ps in cps.passes:
-                    if len(ps.ops) == 1 and ps.ops[0].type == "NpuOp":
+                    if len(ps.ops) == 1 and ps.ops[0].type == Op.CustomNpuOp:
                         # just treat this as a call, unroll it
                         write_subgraph(ps.ops[0].attrs["subgraph"])
                         continue
-                    stats = [ps.name, " ".join(op.type for op in ps.ops)]
+                    stats = [ps.name, " ".join(op.type.name for op in ps.ops)]
                     stats += [ps.placement.name]
                     stats += [cps.strategy.name]
                     stats += list(ps.block_config)
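
Note that the stats writer must now ask for op.type.name explicitly, since op.type is an enum member rather than a string. Assuming Op behaves like a standard Python Enum, the member name reproduces the old report text, e.g.:

    from ethosu.vela.operation import Op

    # ".name" yields the member name, so the joined ops column reads as before
    print(" ".join(op.name for op in (Op.Conv2DBias, Op.MaxPool)))  # Conv2DBias MaxPool
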
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 6ae072f..3d4a09f 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -20,6 +20,7 @@
 from .data_type import BaseType
 from .data_type import DataType
 from .operation import get_slice_offsets
+from .operation import Op
 
 
 # Custom decorator function to allow formatting docstrings containing "{}"
@@ -37,18 +38,18 @@
 
 class SupportedOperators:
     # Categorised lists of supported operators
-    npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead",))
-    convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D",))
-    depthwise_convolution_ops = set(("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D",))
-    transpose_convolution_ops = set(("Conv2DBackpropInput",))
-    max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct",))
-    avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct",))
-    pooling_ops = set(("ReduceSum",)) | max_pooling_ops | avg_pooling_ops
-    resizing_ops = set(("ResizeBilinear",))
-    fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct",))
+    npu_pre_ops = set((Op.SplitSliceRead,))
+    convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
+    depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
+    transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
+    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
+    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
+    pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
+    resizing_ops = set((Op.ResizeBilinear,))
+    fc_vector_products = set((Op.QuantizedMatMul, Op.MatMul, Op.FullyConnected,))
     mac_main_ops = (
         # RNN/LSTM/GRU
-        set(("BlockLSTM",))
+        set((Op.BlockLSTM,))
         # convolutions
         | convolution_ops
         # depth-wise convolutions
@@ -62,45 +63,29 @@
         # FC layers
         | fc_vector_products
     )
-    unary_elem_wise_main_ops = set(("LeakyRelu", "Abs", "CLZ",))
-    binary_elem_wise_min_max_ops = set(("Minimum", "Maximum",))
-    binary_elem_wise_shift_ops = set(("SHL", "SHR",))
-    binary_elem_wise_add_mul_sub = set(
-        ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",)
-    )
+    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
+    binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
+    binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
+    binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
     binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
     supported_int32_tensor_ops = (
-        set(("Requantize", "ReduceSum", "CLZ",)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
+        set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     )
-    activation_ops = set(
-        (
-            "QuantizedRelu",
-            "QuantizedRelu1",
-            "QuantizedRelu6",
-            "Relu",
-            "Relu6",
-            "ReluN1To1",
-            "Sigmoid",
-            "Tanh",
-            "Softmax",
-        )
-    )
+    activation_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Sigmoid, Op.Tanh, Op.Softmax,))
     npu_post_ops = (
-        # concatenation write direction
-        set(("ConcatSliceWrite",))
-        # bias add and batch norm
-        | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm",))
-        # Quantization
-        | set(("Quantize",))
         # activation functions
-        | activation_ops
+        activation_ops
+        # concatenation write direction
+        | set((Op.ConcatSliceWrite,))
+        # Quantization
+        | set((Op.Quantize,))
     )
-    split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack",))
-    concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack",))
-    memory_only_ops = set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims",)) | concat_ops | split_ops
-    shapeless_input_ops = set(("Split", "SplitV",)) | binary_elem_wise_main_ops
-    supported_fused_activations = set(("Relu", "Relu6", "ReluN1To1", "Tanh", "Sigmoid", "LUT",))
+    split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
+    concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
+    memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,)) | concat_ops | split_ops
+    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
+    supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
     supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
     supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
     # Defined ranges for allowed values:
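
Op.op_set is defined in operation.py, which is outside this diff; judging from the call sites above, it collects every Op member that satisfies a predicate. A hypothetical free-function equivalent (the real implementation may differ):

    from ethosu.vela.operation import Op

    def op_set(predicate):
        # Hypothetical: gather all Op members for which the predicate holds
        return {op_type for op_type in Op if predicate(op_type)}

    # usage mirroring the class attributes above:
    max_pooling_ops = op_set(Op.is_maxpool_op)
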
@@ -233,7 +218,7 @@
     @docstring_format_args([supported_fused_activations])
     def constraint_faf(cls, op):
         "The fused activation function (if present) must be one of type: {}"
-        faf = op.attrs.get("fused_activation_function")
+        faf = op.activation
         valid = (faf is None) or (faf in cls.supported_fused_activations)
         extra = "fused_activation_function={}".format(faf)
         return valid, extra
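
The fused activation function has moved from op.attrs["fused_activation_function"] to a dedicated op.activation attribute that holds an Op member or None. A minimal standalone sketch of the same check as constraint_faf above:

    from ethosu.vela.supported_operators import SupportedOperators

    def faf_is_supported(op):
        faf = op.activation  # None when no activation is fused
        return faf is None or faf in SupportedOperators.supported_fused_activations
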
@@ -300,7 +285,7 @@
     @classmethod
     def check_depthwise_convolution_restrictions(cls, op):
         # check depth
-        ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
+        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
         if op.attrs["depth_multiplier"] > 1 and not (
             (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
         ):
@@ -337,9 +322,9 @@
             return False
 
         # check data type
-        ifm_tensor, _, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
+        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
         if ifm_tensor.dtype != ofm_tensor.dtype:
-            if op.type != "ReduceSum":
+            if op.type != Op.ReduceSum:
                 return False
             # TODO: else check ReduceSum restrictions.
 
@@ -365,7 +350,7 @@
     @classmethod
     def check_resize_restrictions(cls, op):
         # check unsupported upscaling factor
-        if op.type == "ResizeBilinear":
+        if op.type == Op.ResizeBilinear:
             if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
                 return True
             if op.inputs[0].shape == op.outputs[0].shape:
@@ -424,10 +409,10 @@
                 ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
             ):
                 return False
-        elif op.type in cls.binary_elem_wise_shift_ops | set(("CLZ")):
+        elif op.type in cls.binary_elem_wise_shift_ops:
             if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
                 return False
-            if op.type in ("CLZ", "SHL") and ofm_tensor.dtype != DataType.int32:
+            if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
                 return False
 
         # check batch size
@@ -438,7 +423,7 @@
                 return False
 
         # negative alpha values are not supported
-        if op.type == "LeakyRelu" and op.attrs["alpha"] < 0:
+        if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
             return False
 
         # check if ifm or ifm2 has ofm shape
@@ -452,7 +437,7 @@
 
     @classmethod
     def check_memory_only_restrictions(cls, op):
-        if op.type == "StridedSlice":
+        if op.type == Op.StridedSlice:
             if len(op.inputs) != 4:
                 warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                 return False
@@ -493,7 +478,7 @@
                     ),
                 )
                 return False
-        if op.type == "SplitV":
+        if op.type == Op.SplitV:
             # check that at most one size is set to -1, indicating a size to be inferred
             sizes = op.inputs[1].values
             num_to_be_inferred = 0
@@ -504,7 +489,7 @@
             if num_to_be_inferred > 1:
                 print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
                 return False
-        if op.type.find("Concat") != -1:
+        if op.type in set((Op.Concat, Op.ConcatTFLite,)):
             axis = op.attrs.get("axis", None)
             if axis is None:
                 print("Warning:", op.type, "invalid or missing axis, placing on CPU")
@@ -554,7 +539,7 @@
 
     @classmethod
     def check_activation_ops(cls, op):
-        if op.type == "Softmax":
+        if op.type == Op.Softmax:
             ifm_tensor = op.inputs[0]
             ofm_tensor = op.outputs[0]
 
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index c0786bf..98dfa3d 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -25,6 +25,7 @@
 from . import numeric_util
 from .data_type import DataType
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .operation import Op
 from .operation import Operation
 from .range_set import MemoryRangeSet
 
@@ -242,7 +243,7 @@
     const_tensor.values = np.array(values, dtype=value_dtype)
     const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8)
     # Operator
-    const_op = Operation("Const", name)
+    const_op = Operation(Op.Const, name)
     const_op.set_output_tensor(const_tensor)
     return const_tensor
 
@@ -258,7 +259,7 @@
     if not ifm_reshape:
         reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
     # Operator
-    reshape_op = Operation("Reshape", name)
+    reshape_op = Operation(Op.Reshape, name)
     reshape_op.attrs["new_shape"] = shape
     reshape_op.add_input_tensor(reshape_ifm)
     reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
@@ -649,7 +650,7 @@
         return strides
 
     def needs_dma(self):
-        return len(self.ops) == 1 and self.ops[0].type == "DMA"
+        return len(self.ops) == 1 and self.ops[0].type == Op.DMA
 
     def get_dma_src_tensor(self):
         # For weight tensors that need DMA: returns the source tensor in Flash, else None
@@ -659,7 +660,7 @@
     def find_npu_op(self):
         # Returns the NPU operator that uses this tensor, excluding DMA operators.
         for op in self.consumers():
-            if op.type == "DMA":
+            if op.type == Op.DMA:
                 return op.outputs[0].find_npu_op()
             if op.run_on_npu:
                 return op
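
needs_dma() and find_npu_op() together let callers reason about weight tensors staged through an Op.DMA producer. A small usage sketch (describe_dma_tensor is illustrative, not part of this change):

    def describe_dma_tensor(tens):
        # A DMA-staged tensor has exactly one producer, of type Op.DMA
        if tens.needs_dma():
            # source tensor in Flash, and the NPU consumer with the DMA op skipped
            return tens.get_dma_src_tensor(), tens.find_npu_op()
        return None, None
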
diff --git a/ethosu/vela/test/test_lut.py b/ethosu/vela/test/test_lut.py
index ee1a40f..44ee0af 100644
--- a/ethosu/vela/test/test_lut.py
+++ b/ethosu/vela/test/test_lut.py
@@ -26,6 +26,7 @@
 from ethosu.vela.data_type import DataType
 from ethosu.vela.high_level_command_stream import DMA
 from ethosu.vela.nn_graph import Graph
+from ethosu.vela.operation import Op
 from ethosu.vela.rewrite_graph import verify_graph_health
 from ethosu.vela.tensor import create_const_tensor
 from ethosu.vela.tensor import TensorPurpose
@@ -94,28 +95,28 @@
     arch = testutil.create_arch()
     shape = [1, 1, 1, 1]
     # u8 LUT op, should lead to DMA
-    op0 = testutil.create_elemwise_op("AddAct", "op0", shape, shape, shape)
+    op0 = testutil.create_elemwise_op(Op.Add, "op0", shape, shape, shape)
     set_256_lut(op0, "lut0")
     # u8 LUT op, should lead to DMA
-    op1 = testutil.create_elemwise_op("AddAct", "op1", shape, shape, shape)
+    op1 = testutil.create_elemwise_op(Op.Add, "op1", shape, shape, shape)
     set_256_lut(op1, "lut1")
     # u8 LUT op with different LUT, should lead to DMA
-    op2 = testutil.create_elemwise_op("AddAct", "op2", shape, shape, shape)
+    op2 = testutil.create_elemwise_op(Op.Add, "op2", shape, shape, shape)
     set_256_lut(op2, "lut2")
     # u8 LUT op with same LUT as in op1, should not lead to DMA
-    op3 = testutil.create_elemwise_op("AddAct", "op3", shape, shape, shape)
+    op3 = testutil.create_elemwise_op(Op.Add, "op3", shape, shape, shape)
     set_256_lut(op3, "lut1")
     # u8 LUT op with same LUT as in op2, should not lead to DMA
-    op4 = testutil.create_elemwise_op("AddAct", "op4", shape, shape, shape)
+    op4 = testutil.create_elemwise_op(Op.Add, "op4", shape, shape, shape)
     set_256_lut(op4, "lut2")
     # 2K LUT op, should lead to DMA, and will overwrite all previous LUTs in SHRAM
-    op5_2K = testutil.create_elemwise_op("AddAct", "op5", shape, shape, shape)
+    op5_2K = testutil.create_elemwise_op(Op.Add, "op5", shape, shape, shape)
     set_2K_lut(op5_2K, "lut5")
     # Another 2K LUT op, should lead to DMA, and will overwrite the previous LUT in SHRAM
-    op6_2K = testutil.create_elemwise_op("AddAct", "op6", shape, shape, shape)
+    op6_2K = testutil.create_elemwise_op(Op.Add, "op6", shape, shape, shape)
     set_2K_lut(op6_2K, "lut6")
     # u8 LUT op with same LUT as in op1, should lead to DMA
-    op7 = testutil.create_elemwise_op("AddAct", "op7", shape, shape, shape)
+    op7 = testutil.create_elemwise_op(Op.Add, "op7", shape, shape, shape)
     set_256_lut(op7, "lut1")
 
     op_list = [op0, op1, op2, op3, op4, op5_2K, op6_2K, op7]
@@ -149,28 +150,28 @@
     arch = testutil.create_arch()
     shape = [1, 1, 1, 1]
     # u8 LUT op, should lead to DMA
-    op0 = testutil.create_elemwise_op("AddAct", "op0", shape, shape, shape)
+    op0 = testutil.create_elemwise_op(Op.Add, "op0", shape, shape, shape)
     set_256_lut(op0, "lut0")
     # u8 LUT op, should lead to DMA
-    op1 = testutil.create_elemwise_op("AddAct", "op1", shape, shape, shape)
+    op1 = testutil.create_elemwise_op(Op.Add, "op1", shape, shape, shape)
     set_256_lut(op1, "lut1")
     # 1K LUT op with different LUT, should lead to DMA
-    op2_1K = testutil.create_elemwise_op("AddAct", "op2", shape, shape, shape)
+    op2_1K = testutil.create_elemwise_op(Op.Add, "op2", shape, shape, shape)
     set_1K_lut(op2_1K, "lut2")
     # u8 LUT op with same LUT as in op1, should not lead to DMA
-    op3 = testutil.create_elemwise_op("AddAct", "op3", shape, shape, shape)
+    op3 = testutil.create_elemwise_op(Op.Add, "op3", shape, shape, shape)
     set_256_lut(op3, "lut1")
     # 1K LUT op with same LUT as in op2, should not lead to DMA
-    op4_1K = testutil.create_elemwise_op("AddAct", "op4", shape, shape, shape)
+    op4_1K = testutil.create_elemwise_op(Op.Add, "op4", shape, shape, shape)
     set_1K_lut(op4_1K, "lut2")
     # 1K LUT op, should lead to DMA, and will overwrite lut2
-    op5_2K = testutil.create_elemwise_op("AddAct", "op5", shape, shape, shape)
+    op5_2K = testutil.create_elemwise_op(Op.Add, "op5", shape, shape, shape)
     set_1K_lut(op5_2K, "lut5")
     # u8 LUT op, lut0 should still be present, should not lead to DMA
-    op6 = testutil.create_elemwise_op("AddAct", "op6", shape, shape, shape)
+    op6 = testutil.create_elemwise_op(Op.Add, "op6", shape, shape, shape)
     set_256_lut(op6, "lut0")
     # 1K LUT op with same LUT as in op2, should lead to DMA
-    op7 = testutil.create_elemwise_op("AddAct", "op7", shape, shape, shape)
+    op7 = testutil.create_elemwise_op(Op.Add, "op7", shape, shape, shape)
     set_1K_lut(op7, "lut2")
 
     op_list = [op0, op1, op2_1K, op3, op4_1K, op5_2K, op6, op7]
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 53c2092..20d448d 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -19,6 +19,7 @@
 import numpy as np
 
 from ethosu.vela.data_type import DataType
+from ethosu.vela.operation import Op
 from ethosu.vela.supported_operators import SupportedOperators
 from ethosu.vela.tensor import create_const_tensor
 from ethosu.vela.tensor import QuantizationParameters
@@ -35,7 +36,7 @@
     in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1])
     out = Tensor(out_shape, DataType.uint8, "out")
     attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
-    return testutil.create_op("StridedSlice", [in0, in1, in2, in3], out, attrs=attrs)
+    return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
 
 
 def create_strided_slice():
@@ -93,21 +94,21 @@
     # Tensors cannot have None in them
     inp = Tensor([1, 8, None, 8], DataType.uint8, "in")
     out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
-    op = testutil.create_op("Relu", [inp], out)
+    op = testutil.create_op(Op.Relu, [inp], out)
     assert not support.is_operator_supported(op)
 
 
 def test_constraint_tens_shapeless():
     # Shapeless input is allowed if it's of a certain type:
-    op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+    op = testutil.create_elemwise_op(Op.Mul, "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
     assert support.is_operator_supported(op)
     # Shapeless output is not allowed at all:
-    op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [1, 8, 8, 8], [])
+    op = testutil.create_elemwise_op(Op.Mul, "scalar_mul", [1, 8, 8, 8], [1, 8, 8, 8], [])
     assert not support.is_operator_supported(op)
     # Invalid shapeless input due to op type:
     inp = Tensor([], DataType.uint8, "in")
     out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
-    op = testutil.create_op("Relu", [inp], out)
+    op = testutil.create_op(Op.Relu, [inp], out)
     assert not support.is_operator_supported(op)
 
 
@@ -115,7 +116,7 @@
     # Tensors cannot be > 4D
     inp = Tensor([1, 1, 8, 8, 8], DataType.uint8, "in")
     out = Tensor([1, 1, 8, 8, 8], DataType.uint8, "out")
-    op = testutil.create_op("Relu", [inp], out)
+    op = testutil.create_op(Op.Relu, [inp], out)
     assert not support.is_operator_supported(op)
 
 
@@ -123,14 +124,14 @@
     # Tensors can only be of type uint8, int8, int16 (and int32)
     inp = Tensor([1, 8, 8, 8], DataType.float32, "in")
     out = Tensor([1, 8, 8, 8], DataType.float32, "out")
-    op = testutil.create_op("Relu", [inp], out)
+    op = testutil.create_op(Op.Relu, [inp], out)
     assert not support.is_operator_supported(op)
     # For int32, only select op types are allowed:
-    op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8], DataType.int32)
+    op = testutil.create_elemwise_op(Op.Mul, "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8], DataType.int32)
     assert support.is_operator_supported(op)
     inp = Tensor([1, 8, 8, 8], DataType.int32, "in")
     out = Tensor([1, 8, 8, 8], DataType.int32, "out")
-    op = testutil.create_op("Relu", [inp], out)
+    op = testutil.create_op(Op.Relu, [inp], out)
     assert not support.is_operator_supported(op)
 
 
@@ -138,11 +139,11 @@
     # Tensor shape dimensions can only be in the inclusive range 1-65535
     inp = Tensor([1, 8, 8, 0], DataType.uint8, "in")
     out = Tensor([1, 8, 8, 0], DataType.uint8, "out")
-    op = testutil.create_op("Relu", [inp], out)
+    op = testutil.create_op(Op.Relu, [inp], out)
     assert not support.is_operator_supported(op)
     inp = Tensor([1, 8, 8, 65536], DataType.uint8, "in")
     out = Tensor([1, 8, 8, 65536], DataType.uint8, "out")
-    op = testutil.create_op("Relu", [inp], out)
+    op = testutil.create_op(Op.Relu, [inp], out)
     assert not support.is_operator_supported(op)
 
 
@@ -150,13 +151,14 @@
     # Fused activation functions, if set, must be a valid op type
     inp = Tensor([1, 8, 8, 8], DataType.uint8, "in")
     out = Tensor([1, 8, 8, 8], DataType.uint8, "out")
-    op = testutil.create_op("Relu", [inp], out, attrs={"fused_activation_function": "Conv2D"})
+    op = testutil.create_op(Op.Relu, [inp], out)
+    op.activation = Op.Conv2D
     assert not support.is_operator_supported(op)
 
 
 def test_constraint_tens_quant_scale():
     # Quantization scale cannot be infinite
-    op = testutil.create_elemwise_op("Mul", "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
+    op = testutil.create_elemwise_op(Op.Mul, "scalar_mul", [1, 8, 8, 8], [], [1, 8, 8, 8])
     op.inputs[0].quantization = QuantizationParameters()
     op.inputs[0].quantization.scale_f32 = np.inf
     assert not support.is_operator_supported(op)
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py
index d63c000..23abb4a 100644
--- a/ethosu/vela/test/test_tflite_reader.py
+++ b/ethosu/vela/test/test_tflite_reader.py
@@ -20,6 +20,7 @@
 
 import pytest
 
+from ethosu.vela.operation import Op
 from ethosu.vela.tflite_reader import TFLiteSubgraph
 
 
@@ -41,13 +42,13 @@
 
     parse_op_testdata = [
         # op_type, opt_serializer, inputs, output, expected
-        ("FullyConnected", None, [0, 1, 2], 3, 3),  # FC
-        ("FullyConnected", None, [0, 1, -1], 3, 3),  # FC disabled Bias
-        ("FullyConnected", None, [0, 1], 3, 3),  # FC no Bias
-        ("Conv2D", None, [2, 1, 3], 0, 3),  # Conv2D
-        ("Conv2DBackprop", None, [0, 1, 2, 3], 4, 4),  # TransposeConv
-        ("Conv2DBackprop", None, [0, 1, 2], 4, 4),  # TransposeConv no Bias
-        pytest.param("Conv2D", None, [0, -1, 1], 3, 3, marks=pytest.mark.xfail),  # Conv2D no Weights
+        (Op.FullyConnected, None, [0, 1, 2], 3, 3),  # FC
+        (Op.FullyConnected, None, [0, 1, -1], 3, 3),  # FC disabled Bias
+        (Op.FullyConnected, None, [0, 1], 3, 3),  # FC no Bias
+        (Op.Conv2D, None, [2, 1, 3], 0, 3),  # Conv2D
+        (Op.Conv2DBackpropInput, None, [0, 1, 2, 3], 4, 4),  # TransposeConv
+        (Op.Conv2DBackpropInput, None, [0, 1, 2], 4, 4),  # TransposeConv no Bias
+        pytest.param(Op.Conv2D, None, [0, -1, 1], 3, 3, marks=pytest.mark.xfail),  # Conv2D no Weights
     ]
 
     @pytest.mark.parametrize("op_type, opt_serializer, inputs, output, expected", parse_op_testdata)
@@ -56,7 +57,7 @@
             # Mock a TFLiteSubGraph
             sg = TFLiteSubgraph(None, None)
             sg.graph = MagicMock()
-            sg.graph.operator_codes = [(op_type, opt_serializer)]
+            sg.graph.operator_codes = [(op_type, opt_serializer, "")]
 
             # Mock a couple of tensors
             sg.tensors = [MagicMock() for _ in range(5)]
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 13b6bf4..adb874a 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -20,7 +20,6 @@
 from ethosu.vela import architecture_features
 from ethosu.vela.data_type import DataType
 from ethosu.vela.nn_graph import Subgraph
-from ethosu.vela.operation import NpuBlockType
 from ethosu.vela.operation import Operation
 from ethosu.vela.tensor import create_const_tensor
 from ethosu.vela.tensor import Tensor
@@ -52,7 +51,6 @@
     op.add_input_tensor(create_const_tensor(name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type))
     ofm = Tensor(ofm_shape, datatype, name + "_ofm")
     op.set_output_tensor(ofm)
-    op.attrs["npu_block_type"] = NpuBlockType.ElementWise
     return op
 
 
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index c25f415..8a53039 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -22,6 +22,8 @@
 import numpy as np
 
 from .data_type import DataType
+from .operation import CustomType
+from .operation import Op
 from .tflite import AbsOptions
 from .tflite import AddNOptions
 from .tflite import AddOptions
@@ -396,16 +398,16 @@
         attrs["custom_options_format"] = op_data.CustomOptionsFormat()
 
         if np.array_equal(custom_options, self.CUSTOM_OPTIONS_NPU_OP):
-            attrs["custom_type"] = "ExistingNpuOp"
+            attrs["custom_type"] = CustomType.ExistingNpuOp
 
         return attrs
 
     def serialize(self, builder, attrs):
-        custom_type = attrs.get("custom_type", "")
+        custom_type = attrs.get("custom_type", CustomType.ThirdPartyOp)
         self.custom_opt_format = attrs.get("custom_options_format", self.CUSTOM_OPTIONS_FORMAT_DEFAULT)
 
         # Set NPU op custom options for the TensorFlow Lite custom operator
-        if custom_type == "NpuOp":
+        if custom_type == CustomType.NpuOp:
             custom_options = self.CUSTOM_OPTIONS_NPU_OP
         else:
             custom_options = attrs.get("custom_options", [])
@@ -426,11 +428,11 @@
 
 activation_function_map = {
     ActivationFunctionType.NONE: None,
-    ActivationFunctionType.RELU: "Relu",
-    ActivationFunctionType.RELU_N1_TO_1: "ReluN1To1",
-    ActivationFunctionType.RELU6: "Relu6",
-    ActivationFunctionType.TANH: "Tanh",
-    ActivationFunctionType.SIGN_BIT: "SignBit",
+    ActivationFunctionType.RELU: Op.Relu,
+    ActivationFunctionType.RELU_N1_TO_1: Op.ReluN1To1,
+    ActivationFunctionType.RELU6: Op.Relu6,
+    ActivationFunctionType.TANH: Op.Tanh,
+    ActivationFunctionType.SIGN_BIT: Op.SignBit,
 }
 
 activation_function_inv_map = inverse_map(activation_function_map)
@@ -478,52 +480,50 @@
 
 is_int_vec = True
 
-custom_prefix = "Custom_"
-
 builtin_operator_map = {
-    BuiltinOperator.ADD: ("AddAct", OptionsSerializer("AddOptions", (fused_act, "pot_scale_int16"))),
-    BuiltinOperator.AVERAGE_POOL_2D: ("AvgPoolAct", pool2d_opts),
-    BuiltinOperator.CONCATENATION: ("ConcatTFLite", OptionsSerializer("ConcatenationOptions", ("axis", fused_act))),
-    BuiltinOperator.CONV_2D: ("Conv2DBiasAct", conv2d_opts),
-    BuiltinOperator.DEPTHWISE_CONV_2D: ("DepthwiseConv2dBiasAct", depthwise_opts),
-    BuiltinOperator.DEPTH_TO_SPACE: ("DepthToSpace", OptionsSerializer("DepthToSpaceOptions", ("block_size",))),
-    BuiltinOperator.DEQUANTIZE: ("Dequantize", OptionsSerializer("DequantizeOptions")),
-    BuiltinOperator.EMBEDDING_LOOKUP: ("EmbeddingLookup", None),
-    BuiltinOperator.FLOOR: ("Floor", None),
+    BuiltinOperator.ADD: (Op.Add, OptionsSerializer("AddOptions", (fused_act, "pot_scale_int16"))),
+    BuiltinOperator.AVERAGE_POOL_2D: (Op.AvgPool, pool2d_opts),
+    BuiltinOperator.CONCATENATION: (Op.ConcatTFLite, OptionsSerializer("ConcatenationOptions", ("axis", fused_act))),
+    BuiltinOperator.CONV_2D: (Op.Conv2DBias, conv2d_opts),
+    BuiltinOperator.DEPTHWISE_CONV_2D: (Op.DepthwiseConv2DBias, depthwise_opts),
+    BuiltinOperator.DEPTH_TO_SPACE: (Op.DepthToSpace, OptionsSerializer("DepthToSpaceOptions", ("block_size",))),
+    BuiltinOperator.DEQUANTIZE: (Op.Dequantize, OptionsSerializer("DequantizeOptions")),
+    BuiltinOperator.EMBEDDING_LOOKUP: (Op.EmbeddingLookup, None),
+    BuiltinOperator.FLOOR: (Op.Floor, None),
     BuiltinOperator.FULLY_CONNECTED: (
-        "FullyConnectedAct",
+        Op.FullyConnected,
         OptionsSerializer("FullyConnectedOptions", (fused_act, "weights_format", "asymmetric_quantize_inputs")),
     ),
-    BuiltinOperator.HASHTABLE_LOOKUP: ("HashtableLookup", None),
-    BuiltinOperator.L2_NORMALIZATION: ("L2NormAct", OptionsSerializer("L2NormOptions", (fused_act,))),
-    BuiltinOperator.L2_POOL_2D: ("L2Pool2D", pool2d_opts),
+    BuiltinOperator.HASHTABLE_LOOKUP: (Op.HashtableLookup, None),
+    BuiltinOperator.L2_NORMALIZATION: (Op.L2Norm, OptionsSerializer("L2NormOptions", (fused_act,))),
+    BuiltinOperator.L2_POOL_2D: (Op.L2Pool2D, pool2d_opts),
     BuiltinOperator.LOCAL_RESPONSE_NORMALIZATION: (
-        "LRN",
+        Op.LRN,
         OptionsSerializer("LocalResponseNormalizationOptions", ("radius", "bias", "alpha", "beta")),
     ),
-    BuiltinOperator.LOGISTIC: ("Sigmoid", None),
-    BuiltinOperator.LSH_PROJECTION: ("LSHProjection", OptionsSerializer("LSHProjectionOptions", ("type",))),
-    BuiltinOperator.LSTM: ("LstmAct", lstm_opts),
-    BuiltinOperator.MAX_POOL_2D: ("MaxPool", pool2d_opts),
-    BuiltinOperator.MUL: ("MulAct", OptionsSerializer("MulOptions", (fused_act,))),
-    BuiltinOperator.RELU: ("Relu", None),
-    BuiltinOperator.RELU_N1_TO_1: ("ReluN1To1", None),
-    BuiltinOperator.RELU6: ("Relu6", None),
-    BuiltinOperator.RESHAPE: ("Reshape", OptionsSerializer("ReshapeOptions", (("new_shape", is_int_vec),))),
+    BuiltinOperator.LOGISTIC: (Op.Sigmoid, None),
+    BuiltinOperator.LSH_PROJECTION: (Op.LSHProjection, OptionsSerializer("LSHProjectionOptions", ("type",))),
+    BuiltinOperator.LSTM: (Op.Lstm, lstm_opts),
+    BuiltinOperator.MAX_POOL_2D: (Op.MaxPool, pool2d_opts),
+    BuiltinOperator.MUL: (Op.Mul, OptionsSerializer("MulOptions", (fused_act,))),
+    BuiltinOperator.RELU: (Op.Relu, None),
+    BuiltinOperator.RELU_N1_TO_1: (Op.ReluN1To1, None),
+    BuiltinOperator.RELU6: (Op.Relu6, None),
+    BuiltinOperator.RESHAPE: (Op.Reshape, OptionsSerializer("ReshapeOptions", (("new_shape", is_int_vec),))),
     BuiltinOperator.RESIZE_BILINEAR: (
-        "ResizeBilinear",
+        Op.ResizeBilinear,
         OptionsSerializer("ResizeBilinearOptions", ("align_corners", "half_pixel_centers")),
     ),
-    BuiltinOperator.RNN: ("RnnAct", rnn_opts),
-    BuiltinOperator.SOFTMAX: ("Softmax", OptionsSerializer("SoftmaxOptions", ("beta",))),
-    BuiltinOperator.SPACE_TO_DEPTH: ("SpaceToDepth", OptionsSerializer("SpaceToDepthOptions", ("block_size",))),
+    BuiltinOperator.RNN: (Op.Rnn, rnn_opts),
+    BuiltinOperator.SOFTMAX: (Op.Softmax, OptionsSerializer("SoftmaxOptions", ("beta",))),
+    BuiltinOperator.SPACE_TO_DEPTH: (Op.SpaceToDepth, OptionsSerializer("SpaceToDepthOptions", ("block_size",))),
     BuiltinOperator.SVDF: (
-        "SvdfAct",
+        Op.Svdf,
         OptionsSerializer("SVDFOptions", ("rank", fused_act, "asymmetric_quantize_inputs")),
     ),
-    BuiltinOperator.TANH: ("Tanh", None),
+    BuiltinOperator.TANH: (Op.Tanh, None),
     BuiltinOperator.CONCAT_EMBEDDINGS: (
-        "ConcatEmbeddings",
+        Op.ConcatEmbeddings,
         OptionsSerializer(
             "ConcatEmbeddingsOptions",
             (
@@ -538,40 +538,40 @@
         ),
     ),
     BuiltinOperator.SKIP_GRAM: (
-        "SkipGram",
+        Op.SkipGram,
         OptionsSerializer("SkipGramOptions", ("ngram_size", "max_skip_size", "include_all_ngrams")),
     ),
-    BuiltinOperator.CALL: ("Call", OptionsSerializer("CallOptions", ("subgraph",))),
+    BuiltinOperator.CALL: (Op.Call, OptionsSerializer("CallOptions", ("subgraph",))),
     BuiltinOperator.EMBEDDING_LOOKUP_SPARSE: (
-        "EmbeddingLookupSparse",
+        Op.EmbeddingLookupSparse,
         OptionsSerializer("EmbeddingLookupSparseOptions", ("combiner",)),
     ),
-    BuiltinOperator.PAD: ("Pad", OptionsSerializer("PadOptions")),
-    BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: ("UnidirectionalSequenceRnnAct", seq_rnn_opts),
-    BuiltinOperator.GATHER: ("GatherV2", OptionsSerializer("GatherOptions", ("axis",))),
-    BuiltinOperator.BATCH_TO_SPACE_ND: ("BatchToSpaceND", OptionsSerializer("BatchToSpaceNDOptions")),
-    BuiltinOperator.SPACE_TO_BATCH_ND: ("SpaceToBatchND", OptionsSerializer("SpaceToBatchNDOptions")),
-    BuiltinOperator.TRANSPOSE: ("Transpose", OptionsSerializer("TransposeOptions")),
-    BuiltinOperator.MEAN: ("Mean", None),
-    BuiltinOperator.SUB: ("SubAct", OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",))),
-    BuiltinOperator.DIV: ("DivAct", OptionsSerializer("DivOptions", (fused_act,))),
-    BuiltinOperator.SQUEEZE: ("Squeeze", OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),))),
-    BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: ("UnidirectionalSequenceLstmAct", unidir_seq_lstm_opts),
+    BuiltinOperator.PAD: (Op.Pad, OptionsSerializer("PadOptions")),
+    BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: (Op.UnidirectionalSequenceRnn, seq_rnn_opts),
+    BuiltinOperator.GATHER: (Op.GatherV2, OptionsSerializer("GatherOptions", ("axis",))),
+    BuiltinOperator.BATCH_TO_SPACE_ND: (Op.BatchToSpaceND, OptionsSerializer("BatchToSpaceNDOptions")),
+    BuiltinOperator.SPACE_TO_BATCH_ND: (Op.SpaceToBatchND, OptionsSerializer("SpaceToBatchNDOptions")),
+    BuiltinOperator.TRANSPOSE: (Op.Transpose, OptionsSerializer("TransposeOptions")),
+    BuiltinOperator.MEAN: (Op.Mean, None),
+    BuiltinOperator.SUB: (Op.Sub, OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",))),
+    BuiltinOperator.DIV: (Op.Div, OptionsSerializer("DivOptions", (fused_act,))),
+    BuiltinOperator.SQUEEZE: (Op.Squeeze, OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),))),
+    BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: (Op.UnidirectionalSequenceLstm, unidir_seq_lstm_opts),
     BuiltinOperator.STRIDED_SLICE: (
-        "StridedSlice",
+        Op.StridedSlice,
         OptionsSerializer(
             "StridedSliceOptions", ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask", "shrink_axis_mask")
         ),
     ),
-    BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: ("BidirectionalSequenceRnnAct", bidir_seq_rnn_opts),
-    BuiltinOperator.EXP: ("Exp", OptionsSerializer("ExpOptions")),
-    BuiltinOperator.TOPK_V2: ("TopKV2", OptionsSerializer("TopKV2Options")),
-    BuiltinOperator.SPLIT: ("Split", OptionsSerializer("SplitOptions", ("num_splits",))),
-    BuiltinOperator.LOG_SOFTMAX: ("LogSoftmax", OptionsSerializer("LogSoftmaxOptions")),
-    BuiltinOperator.DELEGATE: ("Delegate", None),
-    BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM: ("BidirectionalSequenceLstmAct", bidir_seq_lstm_opts),
+    BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: (Op.BidirectionalSequenceRnn, bidir_seq_rnn_opts),
+    BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions")),
+    BuiltinOperator.TOPK_V2: (Op.TopKV2, OptionsSerializer("TopKV2Options")),
+    BuiltinOperator.SPLIT: (Op.Split, OptionsSerializer("SplitOptions", ("num_splits",))),
+    BuiltinOperator.LOG_SOFTMAX: (Op.LogSoftmax, OptionsSerializer("LogSoftmaxOptions")),
+    BuiltinOperator.DELEGATE: (Op.Delegate, None),
+    BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM: (Op.BidirectionalSequenceLstm, bidir_seq_lstm_opts),
     BuiltinOperator.CAST: (
-        "Cast",
+        Op.Cast,
         OptionsSerializer(
             "CastOptions",
             (
@@ -580,109 +580,112 @@
             ),
         ),
     ),
-    BuiltinOperator.PRELU: ("Prelu", None),
-    BuiltinOperator.MAXIMUM: ("Maximum", OptionsSerializer("MaximumMinimumOptions")),
+    BuiltinOperator.PRELU: (Op.Prelu, None),
+    BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions")),
     BuiltinOperator.ARG_MAX: (
-        "ArgMax",
+        Op.ArgMax,
         OptionsSerializer("ArgMaxOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
     ),
-    BuiltinOperator.MINIMUM: ("Minimum", OptionsSerializer("MaximumMinimumOptions")),
-    BuiltinOperator.LESS: ("Less", OptionsSerializer("LessOptions")),
-    BuiltinOperator.NEG: ("Neg", OptionsSerializer("NegOptions")),
-    BuiltinOperator.PADV2: ("PadV2", OptionsSerializer("PadV2Options")),
-    BuiltinOperator.GREATER: ("Greater", OptionsSerializer("GreaterOptions")),
-    BuiltinOperator.GREATER_EQUAL: ("GreaterEqual", OptionsSerializer("GreaterEqualOptions")),
-    BuiltinOperator.LESS_EQUAL: ("LessEqual", OptionsSerializer("LessEqualOptions")),
-    BuiltinOperator.SELECT: ("Select", OptionsSerializer("SelectOptions")),
-    BuiltinOperator.SLICE: ("Slice", OptionsSerializer("SliceOptions")),
-    BuiltinOperator.SIN: ("Sin", None),
+    BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions")),
+    BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions")),
+    BuiltinOperator.NEG: (Op.Neg, OptionsSerializer("NegOptions")),
+    BuiltinOperator.PADV2: (Op.PadV2, OptionsSerializer("PadV2Options")),
+    BuiltinOperator.GREATER: (Op.Greater, OptionsSerializer("GreaterOptions")),
+    BuiltinOperator.GREATER_EQUAL: (Op.GreaterEqual, OptionsSerializer("GreaterEqualOptions")),
+    BuiltinOperator.LESS_EQUAL: (Op.LessEqual, OptionsSerializer("LessEqualOptions")),
+    BuiltinOperator.SELECT: (Op.Select, OptionsSerializer("SelectOptions")),
+    BuiltinOperator.SLICE: (Op.Slice, OptionsSerializer("SliceOptions")),
+    BuiltinOperator.SIN: (Op.Sin, None),
     BuiltinOperator.TRANSPOSE_CONV: (
-        "Conv2DBackpropInput",
+        Op.Conv2DBackpropInput,
         OptionsSerializer("TransposeConvOptions", (padding, "stride_w", "stride_h")),
     ),
     BuiltinOperator.SPARSE_TO_DENSE: (
-        "SparseToDense",
+        Op.SparseToDense,
         OptionsSerializer("SparseToDenseOptions", ("validate_indices",)),
     ),
-    BuiltinOperator.TILE: ("Tile", OptionsSerializer("TileOptions")),
-    BuiltinOperator.EXPAND_DIMS: ("ExpandDims", OptionsSerializer("ExpandDimsOptions")),
-    BuiltinOperator.EQUAL: ("Equal", OptionsSerializer("EqualOptions")),
-    BuiltinOperator.NOT_EQUAL: ("NotEqual", OptionsSerializer("NotEqualOptions")),
-    BuiltinOperator.LOG: ("Log", None),
-    BuiltinOperator.SUM: ("Sum", None),
-    BuiltinOperator.SQRT: ("Sqrt", None),
-    BuiltinOperator.RSQRT: ("Rsqrt", None),
+    BuiltinOperator.TILE: (Op.Tile, OptionsSerializer("TileOptions")),
+    BuiltinOperator.EXPAND_DIMS: (Op.ExpandDims, OptionsSerializer("ExpandDimsOptions")),
+    BuiltinOperator.EQUAL: (Op.Equal, OptionsSerializer("EqualOptions")),
+    BuiltinOperator.NOT_EQUAL: (Op.NotEqual, OptionsSerializer("NotEqualOptions")),
+    BuiltinOperator.LOG: (Op.Log, None),
+    BuiltinOperator.SUM: (Op.Sum, None),
+    BuiltinOperator.SQRT: (Op.Sqrt, None),
+    BuiltinOperator.RSQRT: (Op.Rsqrt, None),
     BuiltinOperator.SHAPE: (
-        "Shape",
+        Op.Shape,
         OptionsSerializer("ShapeOptions", (("out_type", datatype_deserialize, datatype_serialize),)),
     ),
-    BuiltinOperator.POW: ("Pow", OptionsSerializer("PowOptions")),
+    BuiltinOperator.POW: (Op.Pow, OptionsSerializer("PowOptions")),
     BuiltinOperator.ARG_MIN: (
-        "ArgMin",
+        Op.ArgMin,
         OptionsSerializer("ArgMinOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
     ),
     BuiltinOperator.FAKE_QUANT: (
-        "FakeQuantWithMinMaxArgs",
+        Op.FakeQuantWithMinMaxArgs,
         OptionsSerializer("FakeQuantOptions", ("min", "max", "num_bits", "narrow_range")),
     ),
-    BuiltinOperator.REDUCE_PROD: ("Prod", reducer_opts),
-    BuiltinOperator.REDUCE_MAX: ("Max", reducer_opts),
-    BuiltinOperator.PACK: ("Pack", OptionsSerializer("PackOptions", ("values_count", "axis"))),
-    BuiltinOperator.LOGICAL_OR: ("LogicalOr", OptionsSerializer("LogicalOrOptions")),
-    BuiltinOperator.ONE_HOT: ("OneHot", OptionsSerializer("OneHotOptions", ("axis",))),
-    BuiltinOperator.LOGICAL_AND: ("LogicalAnd", OptionsSerializer("LogicalAndOptions")),
-    BuiltinOperator.LOGICAL_NOT: ("LogicalNot", OptionsSerializer("LogicalNotOptions")),
-    BuiltinOperator.UNPACK: ("Unpack", OptionsSerializer("UnpackOptions", ("num", "axis"))),
-    BuiltinOperator.REDUCE_MIN: ("Min", reducer_opts),
-    BuiltinOperator.FLOOR_DIV: ("FloorDiv", OptionsSerializer("FloorDivOptions")),
-    BuiltinOperator.REDUCE_ANY: ("Any", reducer_opts),
-    BuiltinOperator.SQUARE: ("Square", OptionsSerializer("SquareOptions")),
-    BuiltinOperator.ZEROS_LIKE: ("ZerosLike", OptionsSerializer("ZerosLikeOptions")),
-    BuiltinOperator.FILL: ("Fill", OptionsSerializer("FillOptions")),
-    BuiltinOperator.FLOOR_MOD: ("FloorMod", OptionsSerializer("FloorModOptions")),
-    BuiltinOperator.RANGE: ("Range", OptionsSerializer("RangeOptions")),
+    BuiltinOperator.REDUCE_PROD: (Op.Prod, reducer_opts),
+    BuiltinOperator.REDUCE_MAX: (Op.Max, reducer_opts),
+    BuiltinOperator.PACK: (Op.Pack, OptionsSerializer("PackOptions", ("values_count", "axis"))),
+    BuiltinOperator.LOGICAL_OR: (Op.LogicalOr, OptionsSerializer("LogicalOrOptions")),
+    BuiltinOperator.ONE_HOT: (Op.OneHot, OptionsSerializer("OneHotOptions", ("axis",))),
+    BuiltinOperator.LOGICAL_AND: (Op.LogicalAnd, OptionsSerializer("LogicalAndOptions")),
+    BuiltinOperator.LOGICAL_NOT: (Op.LogicalNot, OptionsSerializer("LogicalNotOptions")),
+    BuiltinOperator.UNPACK: (Op.Unpack, OptionsSerializer("UnpackOptions", ("num", "axis"))),
+    BuiltinOperator.REDUCE_MIN: (Op.Min, reducer_opts),
+    BuiltinOperator.FLOOR_DIV: (Op.FloorDiv, OptionsSerializer("FloorDivOptions")),
+    BuiltinOperator.REDUCE_ANY: (Op.Any, reducer_opts),
+    BuiltinOperator.SQUARE: (Op.Square, OptionsSerializer("SquareOptions")),
+    BuiltinOperator.ZEROS_LIKE: (Op.ZerosLike, OptionsSerializer("ZerosLikeOptions")),
+    BuiltinOperator.FILL: (Op.Fill, OptionsSerializer("FillOptions")),
+    BuiltinOperator.FLOOR_MOD: (Op.FloorMod, OptionsSerializer("FloorModOptions")),
+    BuiltinOperator.RANGE: (Op.Range, OptionsSerializer("RangeOptions")),
     BuiltinOperator.RESIZE_NEAREST_NEIGHBOR: (
-        "ResizeNearestNeighbor",
+        Op.ResizeNearestNeighbor,
         OptionsSerializer("ResizeNearestNeighborOptions", ("align_corners", "half_pixel_centers")),
     ),
-    BuiltinOperator.LEAKY_RELU: ("LeakyRelu", OptionsSerializer("LeakyReluOptions", ("alpha",))),
-    BuiltinOperator.SQUARED_DIFFERENCE: ("SquaredDifference", OptionsSerializer("SquaredDifferenceOptions")),
-    BuiltinOperator.MIRROR_PAD: ("MirrorPad", OptionsSerializer("MirrorPadOptions", ("mode",))),
-    BuiltinOperator.ABS: ("Abs", OptionsSerializer("AbsOptions")),
-    BuiltinOperator.SPLIT_V: ("SplitV", OptionsSerializer("SplitVOptions", ("num_splits",))),
+    BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",))),
+    BuiltinOperator.SQUARED_DIFFERENCE: (Op.SquaredDifference, OptionsSerializer("SquaredDifferenceOptions")),
+    BuiltinOperator.MIRROR_PAD: (Op.MirrorPad, OptionsSerializer("MirrorPadOptions", ("mode",))),
+    BuiltinOperator.ABS: (Op.Abs, OptionsSerializer("AbsOptions")),
+    BuiltinOperator.SPLIT_V: (Op.SplitV, OptionsSerializer("SplitVOptions", ("num_splits",))),
     BuiltinOperator.UNIQUE: (
-        "Unique",
+        Op.Unique,
         OptionsSerializer("UniqueOptions", (("idx_out_type", datatype_deserialize, datatype_serialize),)),
     ),
-    BuiltinOperator.CEIL: ("Ceil", None),
-    BuiltinOperator.REVERSE_V2: ("ReverseV2", OptionsSerializer("ReverseV2Options")),
-    BuiltinOperator.ADD_N: ("AddN", OptionsSerializer("AddNOptions")),
-    BuiltinOperator.GATHER_ND: ("GatherNd", OptionsSerializer("GatherNdOptions")),
-    BuiltinOperator.COS: ("Cos", OptionsSerializer("CosOptions")),
-    BuiltinOperator.WHERE: ("Where", OptionsSerializer("WhereOptions")),
-    BuiltinOperator.RANK: ("Rank", OptionsSerializer("RankOptions")),
-    BuiltinOperator.ELU: ("Elu", None),
+    BuiltinOperator.CEIL: (Op.Ceil, None),
+    BuiltinOperator.REVERSE_V2: (Op.ReverseV2, OptionsSerializer("ReverseV2Options")),
+    BuiltinOperator.ADD_N: (Op.AddN, OptionsSerializer("AddNOptions")),
+    BuiltinOperator.GATHER_ND: (Op.GatherNd, OptionsSerializer("GatherNdOptions")),
+    BuiltinOperator.COS: (Op.Cos, OptionsSerializer("CosOptions")),
+    BuiltinOperator.WHERE: (Op.Where, OptionsSerializer("WhereOptions")),
+    BuiltinOperator.RANK: (Op.Rank, OptionsSerializer("RankOptions")),
+    BuiltinOperator.ELU: (Op.Elu, None),
     BuiltinOperator.REVERSE_SEQUENCE: (
-        "ReverseSequence",
+        Op.ReverseSequence,
         OptionsSerializer("ReverseSequenceOptions", ("seq_dim", "batch_dim")),
     ),
-    BuiltinOperator.MATRIX_DIAG: ("MatrixDiag", OptionsSerializer("MatrixDiagOptions")),
-    BuiltinOperator.QUANTIZE: ("Quantize", OptionsSerializer("QuantizeOptions")),
-    BuiltinOperator.MATRIX_SET_DIAG: ("MatrixSetDiag", OptionsSerializer("MatrixSetDiagOptions")),
-    BuiltinOperator.ROUND: ("Round", None),
-    BuiltinOperator.HARD_SWISH: ("HardSwish", OptionsSerializer("HardSwishOptions")),
-    BuiltinOperator.IF: ("If", OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index"))),
-    BuiltinOperator.WHILE: ("While", OptionsSerializer("WhileOptions", ("cond_subgraph_index", "body_subgraph_index"))),
-    BuiltinOperator.NON_MAX_SUPPRESSION_V4: ("NonMaxSuppressionV4", OptionsSerializer("NonMaxSuppressionV4Options")),
-    BuiltinOperator.NON_MAX_SUPPRESSION_V5: ("NonMaxSuppressionV5", OptionsSerializer("NonMaxSuppressionV5Options")),
-    BuiltinOperator.SCATTER_ND: ("ScatterNd", OptionsSerializer("ScatterNdOptions")),
-    BuiltinOperator.SELECT_V2: ("SelectV2", OptionsSerializer("SelectV2Options")),
-    BuiltinOperator.DENSIFY: ("Densify", OptionsSerializer("DensifyOptions")),
-    BuiltinOperator.SEGMENT_SUM: ("SegmentSum", OptionsSerializer("SegmentSumOptions")),
-    BuiltinOperator.BATCH_MATMUL: ("BatchMatMul", OptionsSerializer("BatchMatMulOptions", ("adj_x", "adj_y"))),
-    BuiltinOperator.CUSTOM: (custom_prefix, CustomOptionsSerializer()),
+    BuiltinOperator.MATRIX_DIAG: (Op.MatrixDiag, OptionsSerializer("MatrixDiagOptions")),
+    BuiltinOperator.QUANTIZE: (Op.Quantize, OptionsSerializer("QuantizeOptions")),
+    BuiltinOperator.MATRIX_SET_DIAG: (Op.MatrixSetDiag, OptionsSerializer("MatrixSetDiagOptions")),
+    BuiltinOperator.ROUND: (Op.Round, None),
+    BuiltinOperator.HARD_SWISH: (Op.HardSwish, OptionsSerializer("HardSwishOptions")),
+    BuiltinOperator.IF: (Op.If, OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index"))),
+    BuiltinOperator.WHILE: (
+        Op.While,
+        OptionsSerializer("WhileOptions", ("cond_subgraph_index", "body_subgraph_index")),
+    ),
+    BuiltinOperator.NON_MAX_SUPPRESSION_V4: (Op.NonMaxSuppressionV4, OptionsSerializer("NonMaxSuppressionV4Options")),
+    BuiltinOperator.NON_MAX_SUPPRESSION_V5: (Op.NonMaxSuppressionV5, OptionsSerializer("NonMaxSuppressionV5Options")),
+    BuiltinOperator.SCATTER_ND: (Op.ScatterNd, OptionsSerializer("ScatterNdOptions")),
+    BuiltinOperator.SELECT_V2: (Op.SelectV2, OptionsSerializer("SelectV2Options")),
+    BuiltinOperator.DENSIFY: (Op.Densify, OptionsSerializer("DensifyOptions")),
+    BuiltinOperator.SEGMENT_SUM: (Op.SegmentSum, OptionsSerializer("SegmentSumOptions")),
+    BuiltinOperator.BATCH_MATMUL: (Op.BatchMatMul, OptionsSerializer("BatchMatMulOptions", ("adj_x", "adj_y"))),
+    BuiltinOperator.CUSTOM: (Op.Custom, CustomOptionsSerializer()),
 }
 
 builtin_operator_inv_map = {v[0]: (k, v[1]) for k, v in builtin_operator_map.items()}
 
-builtin_operator_inv_map["NpuOp"] = (BuiltinOperator.CUSTOM, CustomOptionsSerializer())
+builtin_operator_inv_map[Op.CustomNpuOp] = (BuiltinOperator.CUSTOM, CustomOptionsSerializer())
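
builtin_operator_inv_map is now keyed on the Op member, so the writer can map an operator straight back to its TFLite builtin; Op.CustomNpuOp needs the explicit extra entry because BuiltinOperator.CUSTOM already maps to the generic Op.Custom. A lookup sketch:

    from ethosu.vela.operation import Op
    from ethosu.vela.tflite_mapping import builtin_operator_inv_map
    from ethosu.vela.tflite_mapping import BuiltinOperator

    tf_code, opt_serializer = builtin_operator_inv_map[Op.Add]
    assert tf_code == BuiltinOperator.ADD
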
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 77cc796..a03f9ec 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -23,6 +23,7 @@
 from .errors import TensorError
 from .nn_graph import Graph
 from .nn_graph import Subgraph
+from .operation import Op
 from .operation import Operation
 from .tensor import QuantizationParameters
 from .tensor import Tensor
@@ -53,7 +54,7 @@
     if tens.quant_values is not None:
         tens.quant_values = tens.quant_values.transpose(reorder)
 
-    op = Operation("Const", tens.name)
+    op = Operation(Op.Const, tens.name)
     op.set_output_tensor(tens)
     return tens
 
@@ -78,12 +79,12 @@
             if tens.ops != []:
                 TensorError(tens, "This subgraph input tensor has unexpected driving operators.")
 
-            op = Operation("Placeholder", tens.name)
+            op = Operation(Op.Placeholder, tens.name)
             op.set_output_tensor(tens)
 
         for tens in self.tensors:
             if not tens.ops:
-                op = Operation("Const", tens.name)
+                op = Operation(Op.Const, tens.name)
                 op.set_output_tensor(tens)
 
     def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
@@ -136,7 +137,7 @@
         return tens
 
     def parse_operator(self, op_index, op_data):
-        op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()]
+        op_type, opt_serializer, custom_code = self.graph.operator_codes[op_data.OpcodeIndex()]
         inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()]
         outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()]
         name = "unknown_op_name"
@@ -149,19 +150,13 @@
         for out in op.outputs:
             out.ops = [op]
 
-        if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"):
+        if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
             if inputs[1].values is not None:
-                inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
-            if len(inputs) < 3 or (len(inputs) < 4 and "Backprop" in op_type):
-                # No Bias tensor
-                inputs.append(None)
-            if inputs[-1]:
-                inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,))
-
-        if op_type.startswith("FullyConnected"):
-            if inputs[1].values is not None:
-                inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
-            if len(inputs) < 3:
+                if op.type == Op.FullyConnected:
+                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
+                else:
+                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
+            if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
                 # No Bias tensor
                 inputs.append(None)
             if inputs[-1]:
@@ -170,11 +165,11 @@
         if opt_serializer is not None:
             op.attrs = opt_serializer.deserialize(op_data)
 
-            if op_type == "Reshape" and "new_shape" not in op.attrs:
+            if op_type == Op.Reshape and "new_shape" not in op.attrs:
                 # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
                 op.attrs["new_shape"] = outputs[0].shape
 
-            if op_type == "Cast":
+            if op_type == Op.Cast:
                 # Cast op should have "in/out_data_type" attribs; add them if missing
                 if "in_data_type" not in op.attrs:
                     op.attrs["in_data_type"] = inputs[0].dtype
@@ -190,6 +185,9 @@
             if "depth_multiplier" in op.attrs:
                 op.attrs["channel_multiplier"] = op.attrs["depth_multiplier"]
 
+            op.activation = op.attrs.pop("fused_activation_function", None)
+            if custom_code is not None:
+                op.attrs["custom_code"] = custom_code
 
     @staticmethod
     def len1_array_to_scalar(arr):
@@ -260,9 +258,10 @@
             msg = "The input file contains operator code {} which is currently not supported".format(c)
             raise InputFileError(self.name, msg)
         op_type, ser = builtin_operator_map[c]
+        custom_code = None
         if c == BuiltinOperator.CUSTOM:
-            op_type += decode_str(code.CustomCode())
-        return op_type, ser
+            custom_code = decode_str(code.CustomCode())
+        return op_type, ser, custom_code
 
 
 def read_tflite(
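
With the custom code carried as a separate tuple element, a third-party custom operator keeps op.type == Op.Custom and stores its code string in op.attrs["custom_code"]; the writer below then deduplicates operator codes on the (Op, custom code) pair. A sketch of that bookkeeping ("MyVendorOp" is illustrative only):

    from ethosu.vela.operation import Op

    codes = {(Op.Custom, "MyVendorOp"), (Op.Add, ""), (Op.Add, "")}
    assert len(codes) == 2  # duplicate (Op, custom_code) pairs collapse
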
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 68af487..f444ee5 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -22,6 +22,7 @@
 from flatbuffers.builder import UOffsetTFlags
 
 from .nn_graph import PassPlacement
+from .operation import Op
 from .tensor import MemType
 from .tensor import TensorPurpose
 from .tflite import Buffer
@@ -34,7 +35,6 @@
 from .tflite import Tensor
 from .tflite_mapping import builtin_operator_inv_map
 from .tflite_mapping import BuiltinOperator
-from .tflite_mapping import custom_prefix
 from .tflite_mapping import datatype_inv_map
 
 # ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
@@ -77,7 +77,7 @@
         self.scratch_fast_buf_id = 1  # Always assign scratch_fast to buffer 1
         self.buffers_to_write = []  # have an empty array there
 
-        self.ops_to_ignore = set(("Const", "Placeholder", "SubgraphInput"))
+        self.ops_to_ignore = set((Op.Const, Op.Placeholder, Op.SubgraphInput))
 
         self.tensors_to_reshape = {}
 
@@ -89,16 +89,17 @@
                 for op in ps.ops:
                     if op.type not in self.ops_to_ignore:
                         all_ops.append(op)
-                    if op.type.startswith("Conv2D") or op.type.startswith("DepthwiseConv2d"):
+                    if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                         # If values are None, the op has non-constant weights
                         if op.inputs[1].values is not None:
                             self.tensors_to_reshape[op.inputs[1]] = (3, 0, 1, 2)
-                    if op.type.startswith("FullyConnected"):
+                    if op.type == Op.FullyConnected:
                         # If values are None, the op has non-constant weights
                         if op.inputs[1].values is not None:
                             self.tensors_to_reshape[op.inputs[1]] = (1, 0)
 
-        self.operator_codes = list(sorted(set(op.type for op in all_ops)))
+        # list of tuple(Op, string); the custom code is only used for 3rd party custom operators
+        self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
         self.operator_code_map = {}
 
     def write_byte_vector(self, v, alignment=1):
@@ -163,25 +164,25 @@
 
         return buffer_map
 
-    def serialise_operator_code(self, idx, code):
+    def serialise_operator_code(self, idx, op_type, custom_code):
         builder = self.builder
         custom_code_offset = None
-        if code.startswith(custom_prefix):
-            tf_code, opt_serializer = builtin_operator_inv_map[custom_prefix]
-            custom_code_offset = builder.CreateString(code[len(custom_prefix) :])
+        if op_type == Op.Custom:
+            tf_code, opt_serializer = builtin_operator_inv_map[op_type]
+            custom_code_offset = builder.CreateString(custom_code)
         else:
             assert (
-                code in builtin_operator_inv_map
-            ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(code)
-            tf_code, opt_serializer = builtin_operator_inv_map[code]
+                op_type in builtin_operator_inv_map
+            ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
+            tf_code, opt_serializer = builtin_operator_inv_map[op_type]
 
             if tf_code == BuiltinOperator.CUSTOM:
                 assert (
-                    code == "NpuOp"
+                    op_type == Op.CustomNpuOp
                 ), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators"
                 custom_code_offset = builder.CreateString("ethos-u")
 
-        self.operator_code_map[code] = (idx, tf_code, opt_serializer)
+        self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)
 
         OperatorCode.OperatorCodeStart(builder)
         OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
@@ -281,6 +282,8 @@
                 attrs["dilation_w_factor"] = attrs["dilation"][2]
             if "channel_multiplier" in attrs:
                 attrs["depth_multiplier"] = attrs["channel_multiplier"]
+            if op.activation is not None:
+                attrs["fused_activation_function"] = op.activation
 
             builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)
 
@@ -310,7 +313,7 @@
             for op in ps.ops:
                 if op.type not in self.ops_to_ignore:
                     all_ops.append(op)
-                elif op.type == "Placeholder":
+                elif op.type == Op.Placeholder:
                     placeholder_ops.append(op)
 
         # Add the tensors from all valid ops, as well as the tensors from placeholder ops
@@ -404,7 +407,7 @@
     def serialise_model(self):
         builder = self.builder
         operator_code_offset = self.write_offset_vector(
-            [self.serialise_operator_code(idx, code) for idx, code in enumerate(self.operator_codes)]
+            [self.serialise_operator_code(idx, optype, code) for idx, (optype, code) in enumerate(self.operator_codes)]
         )
 
         description = builder.CreateString("Vela Optimised")
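
Reviewer note on the operator-code bookkeeping above: sorted(set(...)) over (op.type, custom_code) tuples only works because Op members are orderable (hence an __lt__ like the id-based one in the sketch after the reader diff), and keying on the pair keeps two distinct third-party custom operators, which both map to the same Op.Custom member, from collapsing into a single operator code. Roughly, with made-up member and custom-code names:

    ops_seen = [
        (Op.Conv2DBias, ""),
        (Op.Custom, "vendor_b_op"),
        (Op.Custom, "vendor_a_op"),  # same Op member, different custom code
        (Op.Conv2DBias, ""),         # exact duplicate, dropped by set()
    ]
    operator_codes = sorted(set(ops_seen))
    # [(Op.Conv2DBias, ''), (Op.Custom, 'vendor_a_op'), (Op.Custom, 'vendor_b_op')]
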
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 8426705..9453521 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -28,6 +28,7 @@
 from .numeric_util import round_up
 from .numeric_util import round_up_divide
 from .operation import NpuBlockType
+from .operation import Op
 from .scaling import quantise_scale
 from .scaling import reduced_quantise_scale
 from .tensor import create_equivalence_id
@@ -336,7 +337,7 @@
     is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
     is_partkernel = tens.block_traversal == TensorBlockTraversal.PartKernelFirst
 
-    if tens.consumer_list[0].type == "Conv2DBackpropInputSwitchedBias":
+    if tens.consumer_list[0].type == Op.Conv2DBackpropInputSwitchedBias:
         # Transpose Convolution, reverse weights in H and W axes
         weights = np.flip(weights, axis=(0, 1))
 
@@ -406,9 +407,9 @@
     assert tens.purpose == TensorPurpose.FeatureMap
     assert tens.format == TensorFormat.NHWC
     # the connected operator should expect a bias input
-    assert "Bias" in tens.consumer_list[0].type or tens.consumer_list[0].type.startswith("FullyConnected")
+    assert tens.consumer_list[0].type.needs_bias()
     # the input bias tensor is the same as that connected to the operator
-    _, _, bias_tens, _ = tens.consumer_list[0].get_ifm_weights_biases_ofm()
+    bias_tens = tens.consumer_list[0].bias
     assert tens is bias_tens
 
     # the operator should only have a single output
@@ -508,14 +509,13 @@
                 op = tens.find_npu_op()
                 if op is None:
                     continue
-                npu_usage_of_tensor = op.attrs["npu_block_type"]
                 needs_dma = tens.needs_dma()
                 if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
                     ofm_depth_step = ps.block_config[-1]
                 else:
                     ofm_depth_step = tens.shape[-1]
                 compress_weights(
-                    arch, nng, tens, npu_usage_of_tensor, ps.block_config[-1], ofm_depth_step, op.get_dilation_h_w()
+                    arch, nng, tens, op.type.npu_block_type, ps.block_config[-1], ofm_depth_step, op.get_dilation_h_w()
                 )
                 # Update source tensor
                 if needs_dma:
@@ -527,7 +527,7 @@
 
             if ps.scale_tensor is not None:
                 rescale_for_faf = False
-                activation_ops = set(("Sigmoid", "Tanh"))
+                activation_ops = set((Op.Sigmoid, Op.Tanh))
                 if (ps.ops[-1].type in activation_ops) and (ps.npu_block_type != NpuBlockType.ElementWise):
                     rescale_for_faf = True
                 calc_scales_and_pack_biases(ps.scale_tensor, arch, ofm_depth_step, rescale_for_faf)
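
Reviewer note: the op.bias accessor used in the weight-compressor hunk above replaces the old get_ifm_weights_biases_ofm() tuple unpacking, but its definition lives in operation.py, outside this section. A plausible minimal reading, built on the hypothetical indices metadata sketched after the reader diff (not the actual Operation code):

    class Operation:
        def __init__(self, op_type, name):
            self.type = op_type
            self.name = name
            self.inputs = []
            self.attrs = {}
            self.activation = None  # fused activation, popped from attrs by the reader

        @property
        def bias(self):
            # Read the bias position from the operator's metadata instead of
            # hard-coding per-type input indices at every call site
            biases = self.type.info.indices.biases
            if not biases or biases[0] >= len(self.inputs):
                return None
            return self.inputs[biases[0]]

With that in place, tens.consumer_list[0].bias reads directly as "the bias input of the consuming operator", which is exactly what the assert above compares against tens.
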