MLBEDSW-4838 Added basic TOSA support.

Added basic TOSA support, enabling Vela to
read and compile  a .tosa file corresponding to
CONV2D + Rescale + Clamp, and writing it to an
optimized .tflite file.

The optimized .tflite file, will in this case, hold
a commandstream where the Rescale and Clamp has been
fused into the CONV2D.

The optimized tflite file is not output from Vela.

  -Added support to read .tosa file into Vela
    internal structure.
      - Added tosa_reader.py, tosa_mapper.py and
        helper files stored under tosa/
      - Support for this limited to ~10 ops

    -Added reader_util.py for functions common
    for TOSA and TFLite

    -Added tosa_graph_optimiser.py
      -Added support to fuse Rescale into convolution
      -Modified handling for padding

    -Added support to fuse Clamp to previous op

    -Added graph_optimiser_util.py
      -Moved functions common for TOSA/TFLite graph
       optimization to this file.

    -Renamed graph_optimiser.py to tflite_graph_optmiser.py

    -Added separate tosa_supported_operators.py

    -Added supported_operator_util.py
       -For functions in common for TOSA/TFLite

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ic3c540504ec8c5eb4771397fdc6882050ecf33ab
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8702966..14d098b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,4 +1,4 @@
-exclude: '^ethosu/vela/(tflite|ethos_u55_regs)/'
+exclude: '^ethosu/vela/(tflite|ethos_u55_regs|tosa)/'
 repos:
 -   repo: https://github.com/asottile/reorder_python_imports
     rev: v2.2.0
diff --git a/README.md b/README.md
index 7d8c09e..eecc3db 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,10 @@
 The tool will also generate performance estimates (EXPERIMENTAL) for the
 compiled model.
 
+The tool has limited functionality for compiling a 
+[TOSA](https://git.mlplatform.org/tosa/specification.git/) neural network
+(EXPERIMENTAL).
+
 ## TensorFlow Support
 
 * Vela 2.1.0 to current supports TensorFlow 2.4
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 19133f5..98d3d8c 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -38,6 +38,7 @@
 from .tensor import MemType
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
+from .tosa_supported_operators import TosaSupportedOperators
 
 
 class Block:
@@ -398,6 +399,7 @@
 
         # Setup supported operators and restriction checkers class
         self.supported_operators = SupportedOperators()
+        self.tosa_supported_operators = TosaSupportedOperators()
 
     # Returns available number of SHRAM banks depending on activation lookup table
     # being used or not
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index a9e3839..cb47539 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -146,14 +146,14 @@
             )
 
 
-def compiler_driver(nng, arch, options, scheduler_options):
+def compiler_driver(nng, arch, options, scheduler_options, network_type):
     assert verify_graph_health(nng)
 
     # Pre-optimisation operator tracking
     for sg in nng.subgraphs:
         visit_graph_post_order(sg.output_tensors, arch, [], [_record_operator])
 
-    nng = graph_optimiser.optimise_graph_a(nng, arch, options.verbose_graph)
+    nng = graph_optimiser.optimise_graph(nng, arch, network_type, options.verbose_graph)
     assert verify_graph_health(nng)
 
     if options.verbose_quantization:
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py
index 3ad642a..07086d6 100644
--- a/ethosu/vela/data_type.py
+++ b/ethosu/vela/data_type.py
@@ -44,9 +44,11 @@
 
     __slots__ = "type", "bits"
 
+    int4: Any
     int8: Any
     int16: Any
     int32: Any
+    int48: Any
     int64: Any
     uint8: Any
     uint16: Any
@@ -113,9 +115,11 @@
 
 
 # generate the standard set of data types
+DataType.int4 = DataType(BaseType.SignedInt, 4)
 DataType.int8 = DataType(BaseType.SignedInt, 8)
 DataType.int16 = DataType(BaseType.SignedInt, 16)
 DataType.int32 = DataType(BaseType.SignedInt, 32)
+DataType.int48 = DataType(BaseType.SignedInt, 48)
 DataType.int64 = DataType(BaseType.SignedInt, 64)
 
 DataType.uint8 = DataType(BaseType.UnsignedInt, 8)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index d2598ae..87e3bc8 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -14,1779 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Description:
-# Early optimisation of the network graph, using the rewrite_graph module to do the traversal of the graph. These are
-# split into two parts optimise_graph_a and optimise_graph_b.
-import math
-import uuid
-from typing import Tuple
-
-import numpy as np
-
-from . import fp_math
-from . import lut
+# Early optimisation of the network graph, using the rewrite_graph module to do the traversal of the graph.
 from . import rewrite_graph
-from . import scaling
-from .api import NpuRoundingMode
-from .data_type import DataType
-from .debug_database import DebugDatabase
-from .errors import UnsupportedFeatureError
-from .errors import VelaError
-from .ethos_u55_regs.ethos_u55_regs import resampling_mode
-from .numeric_util import clamp_sigmoid
-from .numeric_util import full_shape
-from .numeric_util import round_away_zero
-from .operation import create_activation_function
-from .operation import NpuBlockType
-from .operation import Op
-from .operation import Operation
-from .operation import Padding
-from .operation_util import create_avgpool_nop
-from .operation_util import get_pad_values_from_input
-from .shape4d import Shape4D
-from .softmax import SoftMax
-from .tensor import check_quantized_tens_scaling_equal
-from .tensor import create_const_tensor
-from .tensor import create_equivalence_id
-from .tensor import QuantizationParameters
-from .tensor import Tensor
-from .tensor import TensorPurpose
-from .tflite_mapping import optype_to_builtintype
+from .graph_optimiser_util import check_format_restrictions
+from .graph_optimiser_util import check_reshapes
+from .graph_optimiser_util import record_optimised
+from .nn_graph import NetworkType
+from .tflite_graph_optimiser import tflite_optimise_graph
+from .tosa_graph_optimiser import tosa_optimise_graph
 
-passthrough_nodes = (Op.Identity,)
 
-memory_only_ops = (Op.Reshape,)
-
-
-def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
-    """Creates an average pool for the given concat op/input feature map"""
-    ofm = concat_op.ofm
-    avgpool_op = create_avgpool_nop(name)
-    avgpool_op.inputs = [ifm]
-    avgpool_op.outputs = [ofm]
-
-    avgpool_op.write_offset = write_offset
-    avgpool_op.write_shape = ifm_shape
-    ofm.ops.append(avgpool_op)
-    DebugDatabase.add_optimised(concat_op, avgpool_op)
-    avgpool_op.ifm_shapes.append(ifm_shape)
-    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
-    avgpool_op.memory_function = Op.ConcatSliceWrite
-    return avgpool_op
-
-
-def remove_passthrough_tensor(tens, arch, nng):
-    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
-        assert len(tens.ops[0].inputs) == 1
-        tens = tens.ops[0].inputs[0]
-    return tens
-
-
-def rewrite_concat_ops(op, arch):
-    if not op.run_on_npu or not op.type.is_concat_op():
-        return
-
-    axis_4D = 0
-    ofm = op.ofm
-    ofm.ops = []
-    offset = 0
-
-    unfuse_activation_function(op)
-
-    if op.type == Op.Pack:
-        # Pack is also referred to as Stack
-        axis = int(op.attrs["axis"])
-        if axis < 0:  # Convert to positive axis
-            axis = len(op.inputs[0].shape) + 1 + axis
-
-        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
-
-        axis_4D = axis + (4 - len(desired_shape))
-
-        for idx, inp in enumerate(op.inputs):
-            op.ifm_shapes[idx] = Shape4D(desired_shape)
-        op.type = Op.PackReshaped
-
-    inputs, axis = op.get_concat_inputs_axis()
-    for idx, inp in enumerate(inputs):
-        if op.type != Op.PackReshaped:
-            op.ifm_shapes[idx] = Shape4D(inp.shape)
-            if axis >= 0:
-                axis_4D = axis + (4 - len(inp.shape))
-            else:
-                axis_4D = axis
-        write_offset = [0, 0, 0, 0]
-        write_offset[axis_4D] = offset
-        concat_end = offset + op.ifm_shapes[idx][axis_4D]
-        create_avg_pool_for_concat(
-            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
-        )
-        offset = concat_end
-    assert ofm.shape[axis] == offset
-
-    return op
-
-
-def rewrite_split_ops(tens, arch, nng):
-
-    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
-        split_op = tens.ops[0]
-
-        # Not supported so leave it and run on CPU
-        if not split_op.run_on_npu:
-            return tens
-
-        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
-
-        tens.ops = []
-        new_op = Operation(Op.SplitSliceRead, split_op.name)
-        new_op.inputs = [inp]
-        ofm_shape_idx = 0
-        read_shape = offset_end
-
-        # For Split the offset cannot be extracted from the tensor so it has to
-        # be calculated from the index of the output tensor
-        if axis is not None:
-            # Get the start and end of the split
-            offset_start = [0] * 4
-            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
-            for idx, out in enumerate(outputs):
-                if axis_4D_list is not None:
-                    axis_4D = axis_4D_list[idx]
-                else:
-                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
-                    if axis >= 0:
-                        axis_4D = axis + (4 - len(out.shape))
-                    else:
-                        axis_4D = axis
-
-                if out == tens:
-                    ofm_shape_idx = idx
-                    read_shape = split_op.ofm_shapes[idx]
-                    break
-
-                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
-
-        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
-        new_op.read_shapes[0] = read_shape
-        new_op.run_on_npu = True
-        new_op.set_output_tensor(tens)
-        new_op.ifm_shapes.append(Shape4D(inp.shape))
-        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
-        DebugDatabase.add_optimised(split_op, new_op)
-
-    return tens
-
-
-def remove_SplitSliceRead(op, arch):
-
-    if op.type == Op.SplitSliceRead:
-        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool need to be inserted
-        if (
-            len(op.ofm.consumer_list) == 1
-            and op.ofm.consumer_list[0] is not None
-            and op.ofm.consumer_list[0].run_on_npu
-            and op.ofm.consumer_list[0].type != Op.Reshape
-            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
-        ):
-            # SplitSliceRead can be performed by tensor consumer
-            cons_op = op.ofm.consumer_list[0]
-            if cons_op.ifm == op.ofm:
-                cons_op.read_offsets[0] = op.read_offsets[0]
-                cons_op.read_shapes[0] = op.read_shapes[0]
-                cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
-                cons_op.ifm_shapes[0] = op.ifm_shapes[0]
-            elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
-                cons_op.read_offsets[1] = op.read_offsets[0]
-                cons_op.read_shapes[1] = op.read_shapes[0]
-                cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
-                cons_op.ifm_shapes[1] = op.ifm_shapes[0]
-
-            if "skirt" in cons_op.attrs:
-                assert cons_op.attrs["explicit_padding"] == cons_op.attrs["skirt"]
-                cons_op.attrs["skirt"] = None
-                cons_op.attrs["force_padding"] = True
-            op.ofm.consumer_list.remove(cons_op)
-            op.ofm.ops = []
-            op.ifm.consumer_list.remove(op)
-        else:
-            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
-            avgpool_op.add_input_tensor(op.ifm)
-            avgpool_op.outputs = [op.ofm]
-            op.ofm.ops.remove(op)
-            op.ofm.ops.append(avgpool_op)
-            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
-            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
-            avgpool_op.read_offsets[0] = op.read_offsets[0]
-            avgpool_op.read_shapes[0] = op.read_shapes[0]
-
-            op.ifm.consumer_list.remove(op)
-            DebugDatabase.add_optimised(op, avgpool_op)
-
-
-def avoid_nhcwb16_for_concat(tens):
-    # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
-    # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
-    # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
-    # and those addresses are always 16 byte aligned due to the NHCWB16 format.
-    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)
-
-
-def avoid_nhcwb16_for_split(tens):
-    # If read offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
-    for cons_op in tens.consumer_list:
-        if cons_op.ifm == tens:
-            read_offset = cons_op.read_offsets[0]
-        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
-            read_offset = cons_op.read_offsets[1]
-        else:
-            assert False
-        if read_offset is not None and (read_offset[-1] % 16) != 0:
-            return True
-    return False
-
-
-def avoid_nhcwb16_for_shapes(tens):
-    # check all producers/consumers to see if any op shape is preventing NHCWB16
-    for cons_op in tens.consumer_list:
-        if cons_op.ifm == tens:
-            cons_op_shape = cons_op.ifm_shapes[0]
-        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
-            cons_op_shape = cons_op.ifm_shapes[1]
-        else:
-            assert False
-        if Shape4D(tens.shape) != cons_op_shape:
-            return True
-
-    for prod_op in tens.ops:
-        if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
-            return True
-
-    return False
-
-
-# Check if non linear format can be used
-def check_format_restrictions(tens, arch):
-    if len(tens.ops) < 1:
-        return
-    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
-        cons is None for cons in tens.consumer_list
-    ):
-        return
-
-    # Check if any of the producers/consumers is run on CPU
-    if not all(cons.run_on_npu for cons in tens.consumer_list):
-        return
-    if not all(prod.run_on_npu for prod in tens.ops):
-        return
-
-    # "Concat" ofm exception:
-    if avoid_nhcwb16_for_concat(tens):
-        return
-
-    # "Split" ifm exception:
-    if avoid_nhcwb16_for_split(tens):
-        return
-
-    # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
-    if avoid_nhcwb16_for_shapes(tens):
-        return
-
-    for op in tens.consumer_list:
-        if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
-            return
-        if op.type == Op.Reshape:
-            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
-            # consumers do not also need to perform a reshape or if the OFM is going to
-            # be processed by CPU operations. No-op reshape consumers with empty lists
-            # (those that have no consumers, or null-consumers used as list terminators)
-            # must use normal NHWC output.
-
-            def incompatible_consumers(oper):
-                if oper and oper.type == Op.Reshape:
-                    for consumer in oper.outputs[0].consumer_list:
-                        yield from incompatible_consumers(consumer)
-                yield not oper or not oper.run_on_npu
-
-            if not any(incompatible_consumers(op)):
-
-                def get_rewrites(oper):
-                    if oper and oper.type == Op.Reshape:
-                        for consumer in oper.outputs[0].consumer_list:
-                            yield from get_rewrites(consumer)
-                        yield oper
-
-                # Detect no-op reshapes by comparing their full input and output tensor shapes.
-                inshape = op.ifm_shapes[0]
-                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
-                if not (compatible_shape and all(compatible_shape)):
-                    return
-            else:
-                return
-
-    tens.needs_linear_format = False
-
-
-def insert_copy_op_after_tens(tens):
-    tens_cons_list_copy = tens.consumer_list.copy()
-
-    # Create a avg_pool nop op with ifm as input
-    copy_tens = tens.clone()
-    copy_op = create_avgpool_nop(tens.name + "_avgpool")
-    copy_op.add_input_tensor(tens)
-    copy_op.set_output_tensor(copy_tens)
-    copy_op.set_ifm_ofm_shapes()
-    copy_op.run_on_npu = True
-
-    # Set copy_ifm consumers
-    for tens_cons in tens_cons_list_copy:
-        if tens_cons is not None:
-            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
-                if cons_inp == tens:
-                    tens_cons.set_input_tensor(copy_tens, ifm_idx)
-
-    DebugDatabase.add_optimised(tens.ops[0], copy_op)
-
-
-def fix_sg_input_output(op, arch, nng):
-    if not op.run_on_npu or op.type != Op.Reshape:
-        return op
-
-    # For the Reshape operators we want to remove, tensors are removed.
-    # But in order to to do this, they cannot be outputs of the sg,
-    # this need to be fixed prior to the removal.
-    # Solution is to add a avgpool NOP, to maintain the original tensor.
-    # This is also valid when reshape ifm/ofm is produced respectively
-    # consumed by CPU
-
-    # Check if operator ifm/ofm are sg ifm/ofm
-    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
-    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
-    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
-    # Check if ifm/ofm is produced repectivly consumed by CPU
-    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
-    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
-
-    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
-        # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape
-        insert_copy_op_after_tens(op.ifm)
-
-    return op
-
-
-def needed_total_padding(input_size, stride, filter_size):
-    out_size = (input_size + stride - 1) // stride
-    needed_input = (out_size - 1) * stride + filter_size
-    total_padding = max(0, needed_input - input_size)
-    return total_padding
-
-
-def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
-    """
-    Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
-    that provides equivalent results.
-    """
-    total_padding = needed_total_padding(input_size, stride, filter_size)
-    # The top/left padding can be taken as is from the PAD
-    output_pad_before = pad_before
-    # The bottom/right padding might need downward adjustment depending on stride/input size
-    output_pad_after = pad_after
-    while output_pad_after > 0 and output_pad_after % stride != (total_padding - pad_before) % stride:
-        output_pad_after -= 1
-    return output_pad_before, output_pad_after
-
-
-def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
-    k_w, k_h = kernel.dilated_wh()
-    s_x, s_y = kernel.stride
-    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
-    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
-    if padding_type == Padding.SAME:
-        left_pad = (xpad + 0) // 2
-        right_pad = (xpad + 1) // 2
-        top_pad = (ypad + 0) // 2
-        bottom_pad = (ypad + 1) // 2
-    elif padding_type == Padding.VALID:
-        left_pad = 0
-        right_pad = 0
-        top_pad = 0
-        bottom_pad = 0
-    elif padding_type == Padding.EXPLICIT:
-        # Padding is specified in a PAD operator which has been bypassed.
-        top, left, bottom, right = explicit_padding
-        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
-        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
-    else:
-        raise UnsupportedFeatureError(f"Unknown padding")
-    padding = (top_pad, left_pad, bottom_pad, right_pad)
-    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
-    return padding, skirt
-
-
-def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
-    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
-    if padding_type == Padding.SAME:
-        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
-        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
-        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
-        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
-        left_pad = max(kernel_width - 1 - right_pad, 0)
-        top_pad = max(kernel_height - 1 - bottom_pad, 0)
-    elif padding_type == Padding.VALID:
-        right_pad = max(kernel_width - 2, 0)
-        bottom_pad = max(kernel_height - 2, 0)
-        left_pad = kernel_width - 1
-        top_pad = kernel_height - 1
-    else:
-        raise UnsupportedFeatureError(f"Unknown padding")
-    padding = (top_pad, left_pad, bottom_pad, right_pad)
-    skirt = padding
-    return padding, skirt
-
-
-def fixup_conv2d_backprop(op, arch, nng):
-    if op.type == Op.Conv2DBackpropInput:
-        # flip the inputs
-        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
-        op.type = Op.Conv2DBackpropInputSwitchedBias
-        op.ifm.resampling_mode = resampling_mode.TRANSPOSE
-
-        # Update strides
-        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
-
-    return op
-
-
-# Convert the op to an elementwise add
-def convert_resizebilinear_1x1_to_add(op):
-    op.type = Op.Add
-    op.name = op.name + "_add"
-    op.attrs["resizebilinear"] = True
-    # Create an input tensor filled with zeros
-    shape = op.ofm_shapes[0].as_list()
-    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
-    tens.values = np.zeros(shape)
-    tens.quant_values = np.zeros(shape, np.uint8)
-    tens.quantization = QuantizationParameters(0.0, 255.0)
-    tens.quantization.scale_f32 = 1.0
-    tens.quantization.zero_point = 0
-    tens.consumer_list = [op]
-    tens_op = op.inputs[1].ops[0]
-    tens_op.set_output_tensor(tens)
-    # Set the add inputs
-    op.inputs[1] = op.inputs[0]
-    op.inputs[0] = tens
-    op.set_ifm_ofm_shapes()
-
-    return op
-
-
-# Convert ResizeBilinear to a number of 2x2 pool ops
-def convert_resizebilinear_to_2x2_pool(op):
-    count = 0
-    pre_op = op
-    outputs = op.outputs
-
-    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
-    if op.attrs["align_corners"]:
-        shape_modifier = 1
-        op.attrs["padding"] = Padding.VALID
-    else:
-        shape_modifier = 0
-        op.attrs["padding"] = Padding.SAME
-    op.inputs[0].resampling_mode = resampling_mode.NEAREST
-
-    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
-    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
-    if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
-        return op
-
-    while (upscaled_shape < out_shape).all():
-        if count == 0:
-            scaled_op = pre_op
-        else:
-            scaled_op = op.clone("_{}".format(count))
-            scaled_op.inputs[0] = pre_op.outputs[0]
-
-        upscaled_shape = upscaled_shape * 2 - shape_modifier
-
-        if (upscaled_shape == out_shape).all():
-            scaled_op.outputs = outputs
-            scaled_op.outputs[0].ops = [scaled_op]
-        else:
-            shape = op.ofm_shapes[0].as_list()
-            shape[1:3] = upscaled_shape
-            out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
-            out_tens.quantization = op.outputs[0].quantization.clone()
-            out_tens.quantization.quant_min = np.iinfo(np.int16).min
-            out_tens.quantization.quant_max = np.iinfo(np.int16).max
-            scaled_op.set_output_tensor(out_tens)
-            pre_op = scaled_op
-            count += 1
-
-        # Setup the scale value
-        if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
-            scaled_op.rescale = 128
-        elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
-            scaled_op.rescale = 1 / 128
-        else:
-            scaled_op.rescale = None
-        scaled_op.set_ifm_ofm_shapes()
-
-    return op
-
-
-def fixup_resizebilinear(op, arch, nng):
-    if op.type == Op.ResizeBilinear and op.run_on_npu:
-        if op.ifm_shapes[0] == op.ofm_shapes[0]:
-            # Bypass nop resizebilinear
-            op.inputs = op.inputs[:1]
-            op.type = Op.Identity
-        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
-            convert_resizebilinear_1x1_to_add(op)
-        else:
-            convert_resizebilinear_to_2x2_pool(op)
-
-    return op
-
-
-def convert_nop_split_to_identity(op, arch, nng):
-    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
-        # the list comprehension should return a list with a single tensor
-        # if it shouldn't, remove_passthrough_tensor will fail appropriately
-        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
-        op.type = Op.Identity
-    return op
-
-
-def rewrite_fully_connected_input(op, arch, nng):
-    if op.type == Op.FullyConnected:
-        n_in_elems = op.weights.shape[-2]
-        elms = op.ifm.elements()
-        batch_size = elms // n_in_elems
-        assert batch_size * n_in_elems == elms
-
-        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
-    return op
-
-
-def convert_batched_fc_shape(op, arch, nng):
-    if op.type == Op.FullyConnected:
-        # Check if the first dimension indicates batching
-        if op.ifm_shapes[0].batch > 1:
-            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
-            n = op.ifm_shapes[0].batch
-            h, w = batching_split.get(n, (1, n))
-            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
-
-            # Reshape Weights to be 4D. IO becomes HWIO
-            weight_tensor = op.inputs[1]
-            weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
-            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
-
-            n = op.ofm_shapes[0].batch
-            h, w = batching_split.get(n, (1, n))
-            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
-    return op
-
-
-def unfuse_activation_function(op):
-    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
-        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
-        op.activation = None
-        out_tens = op.outputs[0]
-        intermediate_tens = out_tens.clone("_act_intermediate")
-        act_op.set_output_tensor(out_tens)
-        act_op.add_input_tensor(intermediate_tens)
-        op.set_output_tensor(intermediate_tens)
-        act_op.set_ifm_ofm_shapes()
-
-
-def rewrite_stridedslice_output(op, arch, nng):
-    if not op.run_on_npu or op.type != Op.StridedSlice:
-        return op
-
-    new_axis_mask = op.attrs["new_axis_mask"]
-    shrink_axis_mask = op.attrs["shrink_axis_mask"]
-
-    if shrink_axis_mask == 0 and new_axis_mask == 0:
-        return op
-
-    axis_4D = [0] * len(op.outputs)
-    for idx, out_tens in enumerate(op.outputs):
-        output_shape = list(out_tens.shape)
-
-        if shrink_axis_mask != 0:
-            n = 0
-            axis = 0
-            while shrink_axis_mask:
-                prev_mask = shrink_axis_mask
-                n += 1
-                shrink_axis_mask &= shrink_axis_mask - 1
-                axis = int(math.log2(prev_mask - shrink_axis_mask))
-                output_shape = output_shape[:axis] + [1] + output_shape[axis:]
-
-            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
-            op.attrs["shrink_axis_mask"] = 0
-            if axis >= 0:
-                axis_4D[idx] = axis + (4 - len(output_shape))
-            else:
-                axis_4D[idx] = axis
-            op.ofm_shapes[idx] = Shape4D(output_shape)
-
-        elif new_axis_mask != 0:
-            n = 0
-            axis = 0
-            while new_axis_mask:
-                prev_mask = new_axis_mask
-                n += 1
-                new_axis_mask &= new_axis_mask - 1
-                axis = int(math.log2(prev_mask - new_axis_mask))
-                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
-                new_axis_mask >>= 1
-
-            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
-            op.attrs["new_axis_mask"] = 0
-            if axis >= 0:
-                axis_4D[idx] = axis + (4 - len(output_shape))
-            else:
-                axis_4D[idx] = axis
-            op.ofm_shapes[idx] = Shape4D(output_shape)
-
-    op.attrs["split_axis_4D"] = axis_4D
-    return op
-
-
-def rewrite_unpack_output(op, arch, nng):
-    tens = op.outputs[0]
-    if op.run_on_npu and op.type == Op.Unpack:
-        # Unpack is also referred to as Unstack
-        axis = int(op.attrs["axis"])
-        if axis < 0:  # Convert to positive axis
-            axis = len(op.inputs[0].shape) + 1 + axis
-        op.type = Op.UnpackReshaped
-        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
-
-        axis_4D = axis + (4 - len(desired_output_shape))
-        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
-
-        for idx, out_tens in enumerate(op.outputs):
-            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
-    return op
-
-
-def add_padding_fields(op, arch, nng):
-    if op.run_on_npu:
-        if "padding" in op.attrs:
-            input_shape = op.ifm_shapes[0]
-            output_shape = op.ofm_shapes[0]
-            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
-                kernel_size = op.inputs[1].shape[:2]
-            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
-                kernel_size = op.attrs["ksize"][1:3]
-            else:
-                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
-
-            if op.type == Op.Conv2DBackpropInputSwitchedBias:
-                upscaling_factor = output_shape.height // input_shape.height
-                padding, skirt = calc_upscaled_padding_and_skirt(
-                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
-                )
-            else:
-                padding, skirt = calc_padding_and_skirt(
-                    op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
-                )
-
-            op.attrs["explicit_padding"] = padding
-            op.attrs["skirt"] = skirt
-
-    return op
-
-
-def convert_depthwise_to_conv(op, arch, nng):
-    # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
-    # the ofm depth equals the depth multipler.
-    # If those conditions are true, then we can perform a simple
-    # switch of the operator type (and weight order)
-
-    if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
-        ifm_shape = op.ifm_shapes[0]
-        weight_tensor = op.inputs[1]
-        ofm_shape = op.ofm_shapes[0]
-        if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
-            # Change op type to Conv2d
-            op.type = Op.Conv2DBias
-            del op.attrs["channel_multiplier"]
-            del op.attrs["depth_multiplier"]
-
-            weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
-            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
-        else:
-            raise UnsupportedFeatureError(
-                f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
-                f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
-            )
-        DebugDatabase.add_optimised(op, op)
-    return op
-
-
-def reorder_depthwise_weights(op, arch, nng):
-    if op.type.is_depthwise_conv2d_op():
-        weight_tensor = op.inputs[1]
-        weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
-        weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
-        weight_tensor.weight_transpose_depthwise = True
-
-    return op
-
-
-def optimise_strided_conv(op, arch, nng):
-    stride_x, stride_y = op.get_kernel_stride()
-    ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
-
-    if (
-        op.type == Op.Conv2DBias
-        and op.op_index == 0
-        and stride_x == 2
-        and op.ifm_shapes[0].depth <= 4
-        and op.ifm_shapes[0].width % 2 == 0
-        and weight_tensor is not None
-        and weight_tensor.shape[1] >= 2
-    ):
-        ifm_shape = op.ifm_shapes[0]
-        # IFM
-        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
-
-        # Weights
-        weight_shape = weight_tensor.shape
-        if weight_shape[1] % 2 != 0:
-            weight_shape[1] = weight_shape[1] + 1
-            padded_array = np.zeros(weight_shape)
-            for i in range(weight_shape[0]):
-                padded_array[i] = np.vstack(
-                    [
-                        weight_tensor.quant_values[i],
-                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
-                    ]
-                )
-            weight_tensor.quant_values = padded_array
-        weight_shape[1] //= 2
-        weight_shape[2] *= 2
-        weight_tensor.quant_values = np.reshape(weight_tensor.quant_values, weight_shape)
-        weight_tensor.set_all_shapes(weight_shape)
-        # If multiple copies of the weights are used, we could avoid
-        # them having the same address by changing the value_id
-        weight_tensor.value_id = uuid.uuid4()
-
-        # Strides
-        stride_x = 1
-        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
-
-    return op
-
-
-def convert_conv_to_fc(op, arch, nng):
-    # Conv 1x1 can be equivalent to Fully Connected.
-    # By representing certain convs as fully connected layers, Vela can better determine wether or not to use
-    # caching/double buffering for the weights.
-    # (Weights dont need to be reloaded for convs when IFM H and W are 1)
-    if op.type == Op.Conv2DBias:
-        h = op.ifm_shapes[0].height
-        w = op.ifm_shapes[0].width
-        kh, kw, _, _ = op.inputs[1].shape
-        if h == 1 and w == 1 and kh == 1 and kw == 1:
-            # Overwrite this op as a Fully Connected Op
-            op.name += "_fc"
-            op.type = Op.FullyConnected
-            op.attrs = {
-                "weights_format": 0,
-            }
-            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
-            weight_tensor = op.inputs[1]
-            weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
-            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
-
-            DebugDatabase.add_optimised(op, op)
-    return op
-
-
-def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
-    if op.run_on_npu and op.type.is_relu_op():
-        ifm = op.inputs[0]
-        ofm = op.outputs[0]
-        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
-        # and requires its own to be inserted
-        if not check_quantized_tens_scaling_equal(ifm, ofm):
-            # Override this op with its own primary op (avgpool)
-            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
-            # And fuse the original activation function to it
-            relu_fused_op.activation = create_activation_function(op.type)
-            # Tidy up and assign the ifm and ofm to the new op
-            ifm.consumer_list.remove(op)
-
-            relu_fused_op.add_input_tensor(ifm)
-            relu_fused_op.set_output_tensor(ofm)
-            relu_fused_op.set_ifm_ofm_shapes()
-            op = relu_fused_op
-    return op
-
-
-def fixup_elementwise_with_scalars(op, arch, nng):
-    if op.type.is_binary_elementwise_op():
-        ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
-        if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
-            diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
-            if diff > 0:
-                ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
-            elif diff < 0:
-                ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
-        elif ifm_tensor.shape == [] and ifm_tensor.quant_values is None:
-            # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
-            ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
-            ifm_tensor.storage_shape = ifm_tensor.shape
-        elif ifm2_tensor.shape == [] and ifm2_tensor.quant_values is None:
-            # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
-            ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
-            ifm2_tensor.storage_shape = ifm2_tensor.shape
-    return op
-
-
-# Set input/output tensor equivalence to the same id for memory operations
-def set_tensor_equivalence(op, arch, nng):
-    if op.type in memory_only_ops:
-        eid = op.outputs[0].equivalence_id
-        for inp in op.inputs:
-            inp.equivalence_id = eid
-    return op
-
-
-def set_ifm_ofm_op_shapes(op, arch, nng):
-    if op.run_on_npu and op.type.needs_shapes():
-        if op.ifm_shapes or op.ofm_shapes:
-            # Shapes already set
-            return op
-        op.set_ifm_ofm_shapes()
-    return op
-
-
-def convert_softmax(op, arch, nng):
-    if op.type == Op.Softmax and op.run_on_npu:
-        softmax = SoftMax(op)
-        op = softmax.get_graph()
-    return op
-
-
-def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
-    r"""Whenever there is a subgraph with this topology:
-
-       Input    X   For X = -1 or X > 0
-       |   \   /    This subgraph can be replaced with either
-       |    Mul     an Abs (if X = -1) or a LeakyReLU (if X > 0)
-       |   /
-       Max
-    """
-
-    if op.type == Op.Maximum:
-        # finds the Mul input(s) to the Max
-        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
-        if len(muls) == 1:
-            mul = muls[0].ops[0]
-        elif len(muls) == 2:
-            # In the case both inputs are Muls, find the one with the same input as the Max
-            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
-        else:
-            # No Mul inputs
-            return op
-
-        # make sure the Mul doesn't have any other consumers
-        mul_ofm = mul.outputs[0]
-        if len(mul_ofm.consumers()) != 1:
-            return op
-        # make sure the Mul doesn't have a fused activation function
-        if mul.activation:
-            return op
-        ifm, ofm = op.get_ifm_ofm()
-        if ifm is None or ofm is None:
-            return op
-
-        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
-            return op
-        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
-            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
-            return op
-
-        # finds the branched input that goes to both the Max and the Mul
-        shared = set(op.inputs) & set(mul.inputs)
-        if len(shared) == 1:
-            shared_in = shared.pop()
-            # find the constant scalar input to the Mul
-            const_tens = (set(mul.inputs) - {shared_in}).pop()
-            # check that it is a scalar
-            if const_tens.shape != []:
-                return op
-            const = const_tens.ops[0]
-            # check that it is a constant
-            if const.type != Op.Const:
-                return op
-            # Remove the Mul from the shared input's consumers
-            shared_in.consumer_list.remove(mul)
-        else:
-            return op
-
-        val = const.outputs[0].values
-        if val >= 0:
-            new_op = Op.LeakyRelu
-            op.attrs["alpha"] = val
-            # to produce bit exact results, the alpha is not enough;
-            # save additional scaling info in attr "alpha_scale", to be used as input
-            # to the LUT construction
-            alpha_scalar = const_tens.quant_values - const_tens.quantization.zero_point
-            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
-            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
-            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
-            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
-            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
-        elif val == -1:
-            new_op = Op.Abs
-        else:
-            return op
-
-        op.type = new_op
-        op.name = op.name.replace("Maximum", new_op.name)
-        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
-        op.inputs = [shared_in]
-        op.set_ifm_ofm_shapes()
-
-        # Record optimisation in debug database
-        DebugDatabase.add_optimised(op, op)
-
-    return op
-
-
-def convert_hardswish_to_lut(op, arch, nng):
-    if op.type == Op.HardSwish:
-        ifm, ofm = op.get_ifm_ofm()
-        # Generate the LUT
-        ifm_scale = np.double(ifm.quantization.scale_f32)
-        ofm_scale = np.double(ofm.quantization.scale_f32)
-        zp_in = ifm.quantization.zero_point
-        zp_out = ofm.quantization.zero_point
-        ifm_scale_hires = (1 / 128) * ifm_scale
-        relu_multiplier = np.double(3 / 32768)
-        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
-        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
-        # Use 16bit scale
-        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
-        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
-
-        values = []
-        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
-        quantized_min = min(ix)
-        quantized_max = max(ix)
-        for x in ix:
-            input_value = x - zp_in
-            input_value_hires = input_value * 128
-            # Compute the input value on essentially the output scale, not shifted yet
-            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
-            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
-            relu_value = np.int16(input_value_hires)
-            if relu_shift < 31:
-                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
-
-            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
-
-            if relu_shift < 31:
-                relu_value = fp_math.shift_left16(relu_value, 1)
-
-            if relu_shift > 31:
-                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
-
-            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
-            # Now convert that to a 16bit fixedpoint value in [0, 1]
-            relu_value = (relu_value + (1 << 15)) >> 1
-            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
-            shift = 31 - out_shift
-            shift = -shift if shift < 0 else 0
-            # Finally apply the output shift
-            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
-            lut_result = min(quantized_max, max(quantized_min, lut_result))
-            values.append(lut_result)
-        return convert_to_lut(op, values, "hardswish")
-    return op
-
-
-def convert_lrelu_to_mul_max(op, arch):
-    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
-    # (the opposite of convert_mul_max_to_abs_or_lrelu)
-    ifm, ofm = op.get_ifm_ofm()
-    if ifm is None or ofm is None:
-        return op
-
-    # Add multiplication with alpha
-    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
-    mul_alpha.add_input_tensor(ifm)
-    # Create const tensor containing alpha as scalar
-    alpha = op.attrs["alpha"]
-    quantization = ifm.quantization.clone()
-    quantization.min = 0
-    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
-    quantization.zero_point = 0
-    if np.isinf(1 / np.float32(alpha)):
-        # Handling of alpha near zero
-        quantization.scale_f32 = 1
-        scalar = 0
-    else:
-        quantization.scale_f32 = alpha
-        scalar = alpha
-    alpha_tens = create_const_tensor(
-        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
-    )
-    alpha_tens.quant_values = np.array([1])
-    mul_alpha.add_input_tensor(alpha_tens)
-    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
-    mul_alpha.set_output_tensor(fm_alpha)
-    mul_alpha.set_ifm_ofm_shapes()
-    DebugDatabase.add_optimised(op, mul_alpha)
-
-    if check_quantized_tens_scaling_equal(ifm, ofm):
-        # No identity multiplication is needed
-        fm_id = ifm
-    else:
-        # Add multiplication with identity
-        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
-        mul_identity.add_input_tensor(ifm)
-        # Create const tensor containing identity as scalar
-        quantization = ifm.quantization.clone()
-        quantization.min = 0
-        quantization.max = quantization.quant_max - quantization.quant_min
-        quantization.scale_f32 = 1
-        quantization.zero_point = 0
-        identity_tens = create_const_tensor(
-            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
-        )
-        mul_identity.add_input_tensor(identity_tens)
-        # Make sure that fm_id is allocated to a different address than fm_alpha
-        fm_id = ofm.clone(op.name + "_id", set_unique=True)
-        mul_identity.set_output_tensor(fm_id)
-        mul_identity.set_ifm_ofm_shapes()
-        DebugDatabase.add_optimised(op, mul_identity)
-
-    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
-    op.type = Op.Maximum
-    op.name = op.name.replace("LeakyRelu", "Maximum")
-    op.inputs = []
-    ifm.consumer_list.remove(op)
-    op.add_input_tensor(fm_alpha)
-    op.add_input_tensor(fm_id)
-    op.set_ifm_ofm_shapes()
-
-    DebugDatabase.add_optimised(op, op)
-    return op
-
-
-def convert_to_lut(op, lut_values, lut_name):
-    # Rewrite the operation by Add with scalar 0 + LUT activation
-    ifm = op.inputs[0]
-    if ifm is None:
-        return op
-    assert ifm.dtype.size_in_bytes() == 1
-    op.type = Op.Add
-    op.name = op.name + "_lut_" + lut_name
-    # Mark as no-op to enable potential fusing optimizations
-    op.attrs["is_nop"] = True
-    # Create an input tensor containing scalar zero
-    quantization = QuantizationParameters(0.0, 255.0)
-    quantization.scale_f32 = ifm.quantization.scale_f32
-    quantization.zero_point = 0
-    tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
-    op.add_input_tensor(tens)
-    op.ifm_shapes.append(Shape4D(tens.shape))
-
-    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
-    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
-    # should be the same as the IFM
-    op.forced_output_quantization = ifm.quantization
-    lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
-    op.set_activation_lut(lut_tensor)
-    op.set_ifm_ofm_shapes()
-    return op
-
-
-def convert_to_lut8(op, fn, fn_name):
-    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
-    # fn is a function(real) -> real
-    ifm, ofm = op.get_ifm_ofm()
-    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
-        return op
-    # Generate the LUT
-    ifm_scale = np.double(ifm.quantization.scale_f32)
-    ofm_scale = np.double(ofm.quantization.scale_f32)
-    zp_in = ifm.quantization.zero_point
-    zp_out = ofm.quantization.zero_point
-    values = []
-    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
-    quantized_min = min(ix)
-    quantized_max = max(ix)
-    for x in ix:
-        x_real = ifm_scale * (x - zp_in)
-        y_real = fn(x_real)
-        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
-        lut_result = min(quantized_max, max(quantized_min, lut_result))
-        values.append(lut_result)
-    return convert_to_lut(op, values, fn_name)
-
-
-def convert_lrelu_to_lut(op, arch):
-    ifm, ofm = op.get_ifm_ofm()
-    # Generate the LUT
-    alpha = op.attrs["alpha"]
-    ifm_scale = np.double(ifm.quantization.scale_f32)
-    ofm_scale = np.double(ofm.quantization.scale_f32)
-    zp_in = ifm.quantization.zero_point
-    zp_out = ofm.quantization.zero_point
-    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
-    alpha_scalar = 1
-    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
-    if "alpha_scaling" in op.attrs:
-        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
-        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
-    values = []
-    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
-    quantized_min = min(ix)
-    quantized_max = max(ix)
-    for x in ix:
-        if x < zp_in:
-            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
-                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
-            )
-        else:
-            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
-        lut_result = min(quantized_max, max(quantized_min, lut_result))
-        values.append(lut_result)
-    return convert_to_lut(op, values, "lrelu")
-
-
-def convert_lrelu(op, arch, nng):
-    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
-    if op.type != Op.LeakyRelu:
-        return op
-    ifm, ofm = op.get_ifm_ofm()
-    if ifm is None or ofm is None:
-        return op
-    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
-        # use LUT for int8/uint8
-        return convert_lrelu_to_lut(op, arch)
-    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
-        # use LeakyRelu unmodified for int16 with equal input/output scaling
-        return op
-    return convert_lrelu_to_mul_max(op, arch)
-
-
-def convert_tanh_sigmoid_to_lut(op, arch, nng):
-    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
-    if op.type == Op.Sigmoid:
-        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
-    elif op.type == Op.Tanh:
-        return convert_to_lut8(op, math.tanh, "tanh")
-    return op
-
-
-def remove_reshapes(op, arch):
-    if op.run_on_npu and op.type == Op.Reshape:
-        ofm = op.ofm
-        ifm = op.ifm
-
-        # Check if quantization is the same in the input and output for the reshape ops
-        if not check_quantized_tens_scaling_equal(ifm, ofm):
-            # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
-            # In order to remove this reshape either quantization properties need to be moved to Operator,
-            # or the reshape need to be replace with a NOP.
-            return
-
-        # Check if Reshape ifm/ofm are network ifm/ofm
-        ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
-        ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
-        ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
-        # Check if ifm/ofm is produced repectivly consumed by CPU
-        ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
-        ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
-
-        # This case should be handled prior to this function
-        assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
-
-        if ofm_is_sg_ofm or ofm_is_cpu_consumed:
-            # Bypassed by replacing ifm with ofm
-            ofm.ops = []
-            for prev_op in ifm.ops:
-                prev_op.outputs = [ofm]
-                ofm.ops.append(prev_op)
-
-            # All ifm consumers need to use ofm as input
-            for ifm_cons in ifm.consumer_list:
-                for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
-                    if cons_ifm == ifm:
-                        ifm_cons.set_input_tensor(ofm, ifm_idx)
-        else:
-            # Bypassed Reshape by replacing ofm with ifm
-            for cons in ofm.consumer_list:
-                for ifm_idx, cons_ifm in enumerate(cons.inputs):
-                    if cons_ifm == ofm:
-                        cons.set_input_tensor(ifm, ifm_idx)
-
-
-def check_reshapes(op, arch):
-    if op.run_on_npu and op.type == Op.Reshape:
-        ofm = op.ofm
-
-        if check_quantized_tens_scaling_equal(op.ifm, ofm):
-            # Reshape should have been removed
-            raise VelaError(f"Reshape op {op} expected to have been removed, still remains")
-
-
-def fuse_activation_function_with_prev(op, arch, nng):
-    # if op is a no-op: attempts to move the activation function to the preceding op
-    if not op.attrs.get("is_nop", False) or op.activation is None:
-        return op
-    ifm, ofm = op.get_ifm_ofm()
-    if ifm is None or ofm is None:
-        return op
-    # finds the input(s) to the operation
-    prev_op = ifm.ops[0]
-    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
-    fuse = (
-        prev_op.run_on_npu
-        and prev_op.type.npu_block_type != NpuBlockType.Default
-        and len(ifm.ops) == 1
-        and len(prev_op.outputs[0].consumers()) == 1
-        and prev_op.activation is None
-    )
-    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
-        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
-        # LUT currently only works correctly for elementwise ops
-        fuse = False
-    if not fuse:
-        return op
-    # Move the fused activation function + corresponding info to prev_op
-    prev_op.activation = op.activation
-    prev_op.forced_output_quantization = op.forced_output_quantization
-    if op.activation_lut is not None:
-        prev_op.set_activation_lut(op.activation_lut)
-    # Bypass op
-    prev_op.set_output_tensor(ofm)
-    DebugDatabase.add_optimised(op, prev_op)
-    return op
-
-
-def _leading_pad_ok(leading_pad, stride, kernel_size):
-    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
-    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
-    max_size = kernel_size // 2
-    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
-
-
-def replace_pad_by_hw_pad(op: Operation, arch, nng):
-    """
-    Tries to completely remove a PAD operator by using hardware padding.
-    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
-    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
-    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
-    if both operations can be run on the NPU.
-    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
-    """
-    if (
-        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
-        and op.run_on_npu
-        and op.attrs["padding"] == Padding.VALID
-    ):
-        pad_op = op.ifm.ops[0]
-        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
-            return op
-        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
-            return op
-        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
-        k = op.kernel
-        k_w, k_h = k.dilated_wh()
-
-        # Check if the PAD operator can be replaced by hardware padding
-        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
-            # Too much padding, it would require hardware padding to actually insert zeros
-            return op
-        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
-            return op
-
-        if op.type.is_avgpool_op():
-            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
-            for pad, k_size in (
-                (left, k_w),
-                (right, k_w),
-                (top, k_h),
-                (bottom, k_h),
-            ):
-                if pad not in (0, k_size // 2):
-                    return op
-            # Average pool is converted to depthwise, because NPU average pool + same padding
-            # has a special implementation that is different from PAD followed by average pool with
-            # valid padding.
-            k_w, k_h = op.kernel.width, op.kernel.height
-            ifm = op.ifm
-            # Remember other inputs
-            other_inputs = op.inputs[1:]
-            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
-            quantization = QuantizationParameters(0.0, 255.0)
-            quantization.scale_f32 = 1.0 / (k_w * k_h)
-            quantization.zero_point = 0
-            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
-            weights = np.full(shape, 1)
-
-            weight_tens = create_const_tensor(
-                op.name + "_weights",
-                shape,
-                op.ifm.dtype,
-                weights,
-                np.uint8,
-                purpose=TensorPurpose.Weights,
-                quantization=quantization,
-            )
-            weight_tens.quant_values = weights
-            op.type = Op.DepthwiseConv2DBias
-            op.inputs = []
-            op.add_input_tensor(ifm)
-            op.add_input_tensor(weight_tens)
-            # Add bias tensor, all biases set to 0
-            op.inputs.append(None)
-            fixup_bias_tensors(op, arch, nng)
-            # Add other inputs
-            op.inputs.extend(other_inputs)
-            op.rounding_mode = NpuRoundingMode.NATURAL
-
-        # Bypass the PAD operator
-        op.set_input_tensor(pad_op.ifm, 0)
-        # Adjust the padding attributes of the convolution operator
-        op.attrs["padding"] = Padding.EXPLICIT
-        op.attrs["explicit_padding"] = (top, left, bottom, right)
-        op.set_ifm_ofm_shapes()
-    return op
-
-
-def convert_pad(op: Operation, arch, nng):
-    """
-    Rewrites PAD operator to an average pool that copies the IFM to the OFM
-    + up to 4 average pool operators that fill the OFM with zeros at the borders.
-    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
-    """
-    if op.type != Op.Pad or not op.run_on_npu:
-        return op
-    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
-
-    ifm = op.ifm
-    assert ifm is not None
-    ifm_shape = Shape4D(ifm.shape)
-    ofm = op.ofm
-    assert ofm is not None
-    ofm.ops = []
-    ofm_shape = op.ofm_shapes[0]
-
-    # Average pool op that copies IFM to the right place inside the OFM
-    shp0 = Shape4D(0, 0, 0, 0)
-    shp_top = shp0.with_height(top)
-    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
-    avgpool_op.activation = op.activation
-    quant = ofm.quantization
-    pad_value = quant.zero_point
-    # Add operations that fill the borders of the OFM
-    if top > 0:
-        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
-        zero_tens = create_const_tensor(
-            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
-        )
-        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
-        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
-        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
-    if bottom > 0:
-        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
-        zero_tens = create_const_tensor(
-            op.name + "_bottom",
-            shape.as_list(),
-            ofm.dtype,
-            shape.elements() * [pad_value],
-            np.uint8,
-            quantization=quant,
-        )
-        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
-        create_avg_pool_for_concat(
-            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
-        )
-    if left > 0:
-        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
-        zero_tens = create_const_tensor(
-            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
-        )
-        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
-        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
-    if right > 0:
-        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
-        zero_tens = create_const_tensor(
-            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
-        )
-        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
-        create_avg_pool_for_concat(
-            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
-        )
-
-    op.type = Op.ConcatTFLite
-    return avgpool_op
-
-
-def add_attrs_to_resizebilinear(op, arch, nng):
-    if op.type == Op.ResizeBilinear and op.run_on_npu:
-        input_tensor = op.inputs[0]
-        input_shape = op.ifm_shapes[0]
-        upscaled_height = input_shape.height * 2
-        upscaled_width = input_shape.width * 2
-        out_shape = op.ofm_shapes[0]
-        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
-            # this means the output is supposed to be a x2 upscale,
-            # so we need to do SAME padding
-            op.attrs["padding"] = Padding.SAME
-        elif (
-            op.attrs["align_corners"]
-            and out_shape.height == (upscaled_height - 1)
-            and out_shape.width == (upscaled_width - 1)
-        ):
-            # here we can just run the avg pool without padding and
-            # produce a (M * 2 - 1, N * 2 - 1) sized output
-            op.attrs["padding"] = Padding.VALID
-        else:
-            return op
-        input_tensor.resampling_mode = resampling_mode.NEAREST
-        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
-    return op
-
-
-def fixup_bias_tensors(op, arch, nng):
-    if op.type.needs_bias() and op.bias is None:
-        # Op has no bias, add bias tensor filled with zeros
-        nr_biases = op.inputs[1].shape[-1]
-        bias_values = [0] * nr_biases
-        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
-        bias_tensor.quant_values = bias_tensor.values
-        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
-
-    return op
-
-
-def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
-    if op.type == Op.Mean and op.run_on_npu:
-        keep_dims = op.attrs.get("keep_dims", False)
-        inp, axis = op.inputs
-        shape = inp.shape
-        dims = len(shape)
-
-        # Height and width axes have different index depending on dimensions
-        if axis.shape == [] or axis.shape[0] == 1:  # single axis
-            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
-            if dims in (2, 3):
-                if axis == 0:
-                    h, w = shape[axis], 1
-                else:
-                    h, w = 1, shape[axis]
-            else:
-                if axis == 1:
-                    h, w = shape[axis], 1
-                else:
-                    h, w = 1, shape[axis]
-        else:  # multiple axes
-            axis = sorted(axis.values)
-            h, w = [shape[i] for i in axis]
-
-        # Set necessary depthwise attributes
-        op.attrs.update(
-            {
-                "padding": Padding.VALID,
-                "stride_h": 1,
-                "stride_w": 1,
-                "strides": (1, 1, 1, 1),
-                "depth_multiplier": 1,
-                "channel_multiplier": 1,
-                "dilation_h_factor": 1,
-                "dilation_w_factor": 1,
-                "dilation": (1, 1, 1, 1),
-            }
-        )
-        # Change op type
-        op.type = Op.DepthwiseConv2DBias
-        # Set IFM/OFM shapes after changing op type
-        op.set_ifm_ofm_shapes()
-
-        weight_scale, bias = 1, None
-        ofmq, ifmq = op.ofm.quantization, inp.quantization
-        # Set rounding mode, scaling and zero point based on which reference implementation to match
-        if len(shape) == 4 and axis == [1, 2] and keep_dims:
-            if inp.dtype == DataType.uint8:
-                # This attribute means a different scaling calculation is used in order to match reference
-                op.low_precision_scaling = True
-                weight_scale = h * w
-                # Set zero points to 0 as they will be adjusted for with bias term
-                foq = ofmq.clone()
-                foq.zero_point = 0
-                fiq = ifmq.clone()
-                fiq.zero_point = 0
-                op.forced_input_quantization = fiq
-                bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
-                # If the bias term is outside uint8 range, we need an Add op to apply it.
-                if bias_term < 0 or bias_term > 255:
-                    intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
-                    # Bias term has higher bitness (i32) than input/output (u8).
-                    # 16 bits is enough since the bias is added/subtracted from a u8 value,
-                    # the bias can only effectively assume values in the range [-255, 255].
-                    intermediate.dtype = DataType.int16
-                    intermediate.quantization.zero_point = 0
-                    add_op = Operation(Op.Add, op.name + "_bias")
-                    add_op.forced_output_quantization = foq
-                    add_op.add_input_tensor(intermediate)
-                    quant = QuantizationParameters()
-                    quant.zero_point = 0
-                    bias_term_tens = create_const_tensor(
-                        op.name + "_bias",
-                        [1, 1, 1, 1],
-                        DataType.int16,
-                        [bias_term],
-                        np.int16,
-                        quantization=quant,
-                        quant_value_dtype=np.int16,
-                    )
-                    add_op.add_input_tensor(bias_term_tens)
-                    add_op.set_output_tensor(op.ofm)
-                    add_op.set_ifm_ofm_shapes()
-                    add_op.activation = op.activation
-                    op.activation = None
-                    op.set_output_tensor(intermediate)
-                    op.set_ifm_ofm_shapes()
-                # If not, we can just do it with the OFM zero point.
-                else:
-                    foq.zero_point = bias_term
-                    op.forced_output_quantization = foq
-            else:
-                assert inp.dtype == DataType.int8
-                # Use a depthwise to calculate the sum,
-                # followed by a multiplication with 1/N to get the MEAN
-                weight_scale = 1
-                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
-                intermediate.dtype = DataType.int16
-                mul_op = Operation(Op.Mul, op.name + "_mul")
-                mul_op.add_input_tensor(intermediate)
-                # Create scalar containing 1/N
-                quant = QuantizationParameters()
-                quant.zero_point = 0
-                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
-                # while rounding mode NATURAL would round this to -1.
-                # This can only occur if N is even, and can be emulated by
-                # multiplying with a number that is slightly smaller than 1/N.
-                # It must be so small that other roundings are not affected;
-                # the calculated value is based on worst case,
-                # which is sum 256 * N (the maximum sum that can occur with int8)
-                n = int(h * w)
-                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
-                quant.scale_f32 = 1 / (n - eps)
-                scalar = create_const_tensor(
-                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
-                )
-                mul_op.add_input_tensor(scalar)
-                mul_op.set_output_tensor(op.ofm)
-                mul_op.set_ifm_ofm_shapes()
-                mul_op.rounding_mode = NpuRoundingMode.NATURAL
-                mul_op.activation = op.activation
-                op.activation = None
-                op.set_output_tensor(intermediate)
-                op.set_ifm_ofm_shapes()
-        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
-            # Here we can just use a simple AvgPool with truncating rounding,
-            # as we're emulating simple integer division.
-            op.rounding_mode = NpuRoundingMode.TRUNCATE
-            op.type = Op.AvgPool
-            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
-        else:
-            op.rounding_mode = NpuRoundingMode.NATURAL
-            weight_scale = 1 / (h * w)
-            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
-            bias = -ifmq.zero_point * h * w
-            fiq = ifmq.clone()
-            fiq.zero_point = 0
-            op.forced_input_quantization = fiq
-
-        # Change dimensions to 4
-        if dims < 4:
-            shape = [1] + shape
-            if dims == 2:
-                shape += [1]
-
-        # If height is greater than max kernel height, reshape to from HxW to 1x(HxW)
-        if h > 64:
-            shape = [shape[0], 1, h * w, shape[3]]
-            op.ifm_shapes[0] = Shape4D(shape)
-            if h > 256 and op.type == Op.AvgPool:
-                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
-
-        # If the AvgPool version is used, we don't need to do anything else
-        if op.type == Op.AvgPool:
-            return op
-
-        # Make unit weight tensor quantization
-        weight_quant = ifmq.clone()
-        weight_quant.min = 0
-        weight_quant.max = 255
-        weight_quant.scale_f32 = weight_scale
-        weight_quant.zero_point = 0
-
-        # Set weight shape to [H,W,C,B]
-        weight_shape = shape[1:4] + [shape[0]]
-        # Add unit weight tensor
-        op.set_input_tensor(
-            create_const_tensor(
-                "weights",
-                weight_shape,
-                inp.dtype,
-                np.ones(weight_shape),
-                value_dtype=np.uint8,
-                quantization=weight_quant,
-            ),
-            1,
-        )
-        op.weights.quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
-
-        # Add None bias tensor
-        op.inputs.append(None)
-        # Add bias tensor
-        if bias:
-            bias_shape = [shape[-1]]
-            op.set_input_tensor(
-                create_const_tensor(
-                    "bias",
-                    bias_shape,
-                    inp.dtype,
-                    np.ones(bias_shape) * bias,
-                    value_dtype=np.int32,
-                    quant_value_dtype=np.int32,
-                    quantization=None,
-                ),
-                2,
-            )
-
-    return op
-
-
-def supported_operator_check(op, arch, nng):
-    op.run_on_npu = arch.supported_operators.is_operator_supported(op)
-    return op
-
-
-def _record_optimised(op, arch):
-    if op.type != Op.Const:
-        DebugDatabase.add_optimised(op, op)
-
-
-def optimise_graph_a(nng, arch, verbose_graph=False):
+def optimise_graph(nng, arch, network_type, verbose_graph=False):
     if verbose_graph:
         nng.print_graph("Before Graph Optimization")
 
-    pre_process_list = [
-        supported_operator_check,
-        set_ifm_ofm_op_shapes,
-        # TODO: memory-only Op removal
-    ]
-
-    for idx, sg in enumerate(nng.subgraphs):
-        # rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
-        )
-
-    # Handle Concat Ops
-    for idx, sg in enumerate(nng.subgraphs):
-        # rewrite graph pass
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
-        sg.refresh_after_modification()
-
-    # Handle Split Ops
-    for idx, sg in enumerate(nng.subgraphs):
-        # rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng,
-            sg,
-            arch,
-            [],
-            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
-            rewrite_unsupported=False,
-        )
-
-    for idx, sg in enumerate(nng.subgraphs):
-        # rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
-        )
-
-    # Handle sg input output
-    for idx, sg in enumerate(nng.subgraphs):
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
-        )
-
-    # Removal of reshapes
-    for sg in nng.subgraphs:
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
-        sg.refresh_after_modification()
-
-    op_rewrite_list = [
-        set_tensor_equivalence,
-        convert_mean_to_depthwise_conv_or_avgpool,
-        convert_depthwise_to_conv,
-        convert_conv_to_fc,
-        convert_softmax,
-        optimise_strided_conv,
-        convert_hardswish_to_lut,
-        rewrite_fully_connected_input,
-        convert_batched_fc_shape,
-        fixup_conv2d_backprop,
-        fixup_relus_with_differing_ifm_ofm_scaling,
-        fixup_elementwise_with_scalars,  # TODO Move to early stage?
-        reorder_depthwise_weights,
-        fixup_resizebilinear,
-        fixup_bias_tensors,
-        convert_mul_max_to_abs_or_lrelu,
-        convert_lrelu,
-        convert_tanh_sigmoid_to_lut,
-        replace_pad_by_hw_pad,
-    ]
-
-    for idx, sg in enumerate(nng.subgraphs):
-        # rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
-        )
-
-    for idx, sg in enumerate(nng.subgraphs):
-        # remove passthrough tensors and attempt further optimizations
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng,
-            sg,
-            arch,
-            [remove_passthrough_tensor],
-            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
-        )
-
-    # Removal of SplitSliceRead, need to be done after optimisation has been performed,
-    # since ifm/ofm_shapes are of importance to this function
-    for sg in nng.subgraphs:
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
-        sg.refresh_after_modification()
-
-    # Check Tensor Format restrictions
-    for sg in nng.subgraphs:
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [check_format_restrictions], [])
-        sg.refresh_after_modification()
+    if network_type == NetworkType.TFLite:
+        # TensorFlow Lite graph optimization
+        nng = tflite_optimise_graph(nng, arch)
+    else:
+        # TOSA graph optimization
+        nng = tosa_optimise_graph(nng, arch)
 
     # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
     for sg in nng.subgraphs:
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised])
+        rewrite_graph.visit_graph_post_order(
+            sg.output_tensors, arch, [check_format_restrictions], [check_reshapes, record_optimised]
+        )
 
     if verbose_graph:
         nng.print_graph("After Graph Optimization")
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
new file mode 100644
index 0000000..0b44b8f
--- /dev/null
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -0,0 +1,168 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Common functions and definitions used during the graph optimization.
+from .data_type import DataType
+from .debug_database import DebugDatabase
+from .errors import VelaError
+from .operation import Op
+from .shape4d import Shape4D
+from .tensor import check_quantized_tens_scaling_equal
+
+
+memory_only_ops = (Op.Reshape,)
+
+
+def _avoid_nhcwb16_for_concat(tens):
+    # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
+    # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
+    # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
+    # and those addresses are always 16 byte aligned due to the NHCWB16 format.
+    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)
+
+
+def _avoid_nhcwb16_for_split(tens):
+    # If read offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
+    for cons_op in tens.consumer_list:
+        if cons_op.ifm == tens:
+            read_offset = cons_op.read_offsets[0]
+        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+            read_offset = cons_op.read_offsets[1]
+        else:
+            assert False
+        if read_offset is not None and (read_offset[-1] % 16) != 0:
+            return True
+    return False
+
+
+def _avoid_nhcwb16_for_shapes(tens):
+    # check all producers/consumers to see if any op shape is preventing NHCWB16
+    for cons_op in tens.consumer_list:
+        if cons_op.ifm == tens:
+            cons_op_shape = cons_op.ifm_shapes[0]
+        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+            cons_op_shape = cons_op.ifm_shapes[1]
+        else:
+            assert False
+        if Shape4D(tens.shape) != cons_op_shape:
+            return True
+
+    for prod_op in tens.ops:
+        if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
+            return True
+
+    return False
+
+
+# Check if non linear format can be used
+def check_format_restrictions(tens, arch):
+    if len(tens.ops) < 1:
+        return
+    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
+        cons is None for cons in tens.consumer_list
+    ):
+        return
+
+    # Check if any of the producers/consumers is run on CPU
+    if not all(cons.run_on_npu for cons in tens.consumer_list):
+        return
+    if not all(prod.run_on_npu for prod in tens.ops):
+        return
+
+    # "Concat" ofm exception:
+    if _avoid_nhcwb16_for_concat(tens):
+        return
+
+    # "Split" ifm exception:
+    if _avoid_nhcwb16_for_split(tens):
+        return
+
+    # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
+    if _avoid_nhcwb16_for_shapes(tens):
+        return
+
+    for op in tens.consumer_list:
+        if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
+            return
+        if op.type == Op.Reshape:
+            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
+            # consumers do not also need to perform a reshape or if the OFM is going to
+            # be processed by CPU operations. No-op reshape consumers with empty lists
+            # (those that have no consumers, or null-consumers used as list terminators)
+            # must use normal NHWC output.
+
+            def incompatible_consumers(oper):
+                if oper and oper.type == Op.Reshape:
+                    for consumer in oper.outputs[0].consumer_list:
+                        yield from incompatible_consumers(consumer)
+                yield not oper or not oper.run_on_npu
+
+            if not any(incompatible_consumers(op)):
+
+                def get_rewrites(oper):
+                    if oper and oper.type == Op.Reshape:
+                        for consumer in oper.outputs[0].consumer_list:
+                            yield from get_rewrites(consumer)
+                        yield oper
+
+                # Detect no-op reshapes by comparing their full input and output tensor shapes.
+                inshape = op.ifm_shapes[0]
+                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
+                if not (compatible_shape and all(compatible_shape)):
+                    return
+            else:
+                return
+
+    tens.needs_linear_format = False
+
+
+def needed_total_padding(input_size, stride, filter_size):
+    out_size = (input_size + stride - 1) // stride
+    needed_input = (out_size - 1) * stride + filter_size
+    total_padding = max(0, needed_input - input_size)
+    return total_padding
+
+
+# Set input/output tensor equivalence to the same id for memory operations
+def set_tensor_equivalence(op, arch, nng):
+    if op.type in memory_only_ops:
+        eid = op.outputs[0].equivalence_id
+        for inp in op.inputs:
+            inp.equivalence_id = eid
+    return op
+
+
+def set_ifm_ofm_op_shapes(op, arch, nng):
+    if op.run_on_npu and op.type.needs_shapes():
+        if op.ifm_shapes or op.ofm_shapes:
+            # Shapes already set
+            return op
+        op.set_ifm_ofm_shapes()
+    return op
+
+
+def check_reshapes(op, arch):
+    if op.run_on_npu and op.type == Op.Reshape:
+        ofm = op.ofm
+
+        if check_quantized_tens_scaling_equal(op.ifm, ofm):
+            # Reshape should have been removed
+            raise VelaError(f"Reshape op {op} expected to have been removed, still remains")
+
+
+def record_optimised(op, arch):
+    if op.type != Op.Const:
+        DebugDatabase.add_optimised(op, op)
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 6fcf80c..b3ea9d4 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -109,7 +109,9 @@
     # Create activation function if needed
     for op in ps.ops:
         if op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid):
-            ps.primary_op.activation = create_activation_function(op.type)
+            ps.primary_op.activation = create_activation_function(
+                op.type, min=op.attrs.get("min", None), max=op.attrs.get("max", None)
+            )
 
     # Generate commands for the Op that produces this Op's IFM, if applicable
     if cascade_info is None or cascade_info.start == sched_op.index:
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py
index bb49b64..f48645d 100644
--- a/ethosu/vela/model_reader.py
+++ b/ethosu/vela/model_reader.py
@@ -16,7 +16,9 @@
 # Description:
 # Dispatcher for reading a neural network model.
 from . import tflite_reader
+from . import tosa_reader
 from .errors import InputFileError
+from .nn_graph import NetworkType
 
 
 class ModelReaderOptions:
@@ -37,12 +39,33 @@
             output_node_names = []
         if initialisation_nodes is None:
             initialisation_nodes = []
-        return tflite_reader.read_tflite(
-            fname,
-            options.batch_size,
-            feed_dict=feed_dict,
-            output_node_names=output_node_names,
-            initialisation_nodes=initialisation_nodes,
+        return (
+            tflite_reader.read_tflite(
+                fname,
+                options.batch_size,
+                feed_dict=feed_dict,
+                output_node_names=output_node_names,
+                initialisation_nodes=initialisation_nodes,
+            ),
+            NetworkType.TFLite,
+        )
+    elif fname.endswith(".tosa"):
+        if feed_dict is None:
+            feed_dict = {}
+        if output_node_names is None:
+            output_node_names = []
+        if initialisation_nodes is None:
+            initialisation_nodes = []
+
+        return (
+            tosa_reader.read_tosa(
+                fname,
+                options.batch_size,
+                feed_dict=feed_dict,
+                output_node_names=output_node_names,
+                initialisation_nodes=initialisation_nodes,
+            ),
+            NetworkType.TOSA,
         )
     else:
-        raise InputFileError(fname, "Unsupported file extension. Only .tflite files are supported")
+        raise InputFileError(fname, "Unsupported file extension. Only .tflite and .tosa files are supported")
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 7dc2d72..97afde3 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -44,6 +44,11 @@
         return self.name
 
 
+class NetworkType(enum.Enum):
+    TFLite = 1
+    TOSA = 2
+
+
 class Pass:
     def __init__(self, name, placement, is_element_wise, npu_block_type):
         self.inputs = []
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6bd955d..0558e52 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -238,6 +238,8 @@
     Relu = OperatorInfo(indices=IFM_INDICES)
     Relu6 = OperatorInfo(indices=IFM_INDICES)
     ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
+    ReluN = OperatorInfo(indices=IFM_INDICES)  # TOSA specific
+    Rescale = OperatorInfo(indices=IFM_INDICES)  # TOSA specific
     RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
     Reshape = OperatorInfo(indices=IFM_INDICES)
     ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
@@ -321,7 +323,7 @@
         return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary
 
     def is_relu_op(self):
-        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip)
+        return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip)
 
     def is_activation_op(self):
         return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)
@@ -374,7 +376,20 @@
         return res
 
 
-def create_activation_function(op_type: Op) -> ActivationFunction:
+class ExplicitScaling:
+    """Explicit scaling parameters"""
+
+    def __init__(self, per_channel, shift, multiplier):
+        self.per_channel = per_channel
+        self.shift = shift
+        self.multiplier = multiplier
+
+    def clone(self):
+        res = copy.copy(self)
+        return res
+
+
+def create_activation_function(op_type: Op, min=None, max=None) -> ActivationFunction:
     """Creates activation function with min/max depending on op_type"""
     act = ActivationFunction(op_type)
     if op_type == Op.Relu:
@@ -393,6 +408,15 @@
         act.max = 1.0
     elif op_type == Op.HardSwish:
         act.min = 0.0
+    if op_type == Op.Clip:
+        assert min is not None and max is not None
+        act.min = min
+        act.max = max
+    elif op_type == Op.ReluN:
+        assert max is not None
+        act.min = 0.0
+        act.max = max
+
     return act
 
 
@@ -436,6 +460,7 @@
         "read_offsets",
         "read_shapes",
         "rounding_mode",
+        "explicit_scaling",
         "low_precision_scaling",
         "write_offset",
         "write_shape",
@@ -470,6 +495,8 @@
         self.read_offsets: List[Shape4D] = [None, None]  # offset for [ifm, ifm2]
         self.read_shapes: List[Shape4D] = [None, None]  # read shape for [ifm, ifm2]
         self.rounding_mode: Optional[NpuRoundingMode] = None
+        # Rescale op in TOSA supplies explicit multiplier and shift values
+        self.explicit_scaling: Optional[ExplicitScaling] = None
         # The Mean operator (implemented as a depthwise convolution) requires scaling
         # to be calculated differently in one case. In that case, this is set to True.
         self.low_precision_scaling = False
@@ -498,6 +525,7 @@
         res.read_offsets = list(self.read_offsets)
         res.read_shapes = list(self.read_shapes)
         res.rounding_mode = self.rounding_mode
+        res.explicit_scaling = self.explicit_scaling
         res.low_precision_scaling = self.low_precision_scaling
 
         return res
diff --git a/ethosu/vela/reader_util.py b/ethosu/vela/reader_util.py
new file mode 100644
index 0000000..5b454b5
--- /dev/null
+++ b/ethosu/vela/reader_util.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Utlity function for reading .tosa and .tflite files
+from .operation import Op
+from .operation import Operation
+
+
+def decode_str(s):
+    if s is None:
+        return ""
+    return s.decode("utf-8")
+
+
+def clone_and_reshape_tensor(src_tens, reorder, set_unique):
+    tens = src_tens.clone("_reshape", set_unique)
+    tens.shape = [src_tens.shape[idx] for idx in reorder]
+    tens.bandwidth_shape = tens.shape
+    tens.storage_shape = tens.shape
+
+    if tens.values is not None:
+        tens.values = tens.values.transpose(reorder)
+
+    if tens.quant_values is not None:
+        tens.quant_values = tens.quant_values.transpose(reorder)
+
+    op = Operation(Op.Const, tens.name)
+    op.set_output_tensor(tens)
+    return tens
+
+
+# Fix up tensors without operations. Generate either Placeholder or Constant ops
+def fixup_tensors(input_tensors, tensors):
+    for tens in input_tensors:
+        if len(tens.ops) and tens.ops[0].type == Op.Const:
+            break
+
+        if tens.ops != []:
+            tens.error("This subgraph input tensor has unexpected driving operators.")
+
+        op = Operation(Op.Placeholder, tens.name)
+        op.set_output_tensor(tens)
+
+    for tens in tensors:
+        if not tens.ops:
+            op = Operation(Op.Const, tens.name)
+            op.set_output_tensor(tens)
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index dfa2719..c993da1 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -25,31 +25,19 @@
 from .operation import get_slice_offsets
 from .operation import Op
 from .operation import Padding
+from .supported_operators_util import docstring_format_args
+from .supported_operators_util import list_formatter
 from .tensor import check_quantized_tens_scaling_equal
 from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
 from .tflite_mapping import optype_to_builtintype
 
 
-# Custom decorator function to allow formatting docstrings containing "{}"
-def docstring_format_args(args):
-    def docstring(func):
-        func.__doc__ = func.__doc__.format(*args)
-        return func
-
-    return docstring
-
-
-def _list_formatter(arg):
-    # Order and join into a string representation
-    return ", ".join(sorted(map(str, arg)))
-
-
 def _optype_formatter(op_list):
     # Convert internal op types to external names
     output = map(optype_to_builtintype, op_list)
     # Remove UNKNOWNs
     output = (x for x in output if x is not BUILTIN_OPERATOR_UNKNOWN)
-    return _list_formatter(output)
+    return list_formatter(output)
 
 
 class SupportedOperators:
@@ -88,7 +76,8 @@
     supported_int32_tensor_ops = (
         set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     )
-    relu_ops = Op.op_set(Op.is_relu_op)
+
+    relu_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip,))
     activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
     npu_post_ops = (
         # activation functions
@@ -354,7 +343,7 @@
         return valid, ", ".join(extra)
 
     @classmethod
-    @docstring_format_args([_list_formatter(supported_op_dtypes)])
+    @docstring_format_args([list_formatter(supported_op_dtypes)])
     def constraint_tens_dtype(cls, op):
         "Tensors must be of type: {}"
         valid = True
@@ -463,7 +452,7 @@
         return res
 
     @classmethod
-    @docstring_format_args([_list_formatter(supported_faf_dtypes)])
+    @docstring_format_args([list_formatter(supported_faf_dtypes)])
     def constraint_faf_type(cls, op):
         "If a fused activation function is present, the Output tensor must be one of type: {}"
         if op.activation is None:
@@ -549,7 +538,7 @@
         return valid, f"Tensor '{weights.name}' has the sum of weights: {limit}"
 
     @classmethod
-    @docstring_format_args([_list_formatter(supported_bias_dtypes)])
+    @docstring_format_args([list_formatter(supported_bias_dtypes)])
     def constraint_bias_type(cls, op):
         "Optional Bias tensor must be of type: {}"
         bias = op.bias
@@ -832,7 +821,7 @@
         return valid, f"The pad tensor has the shape: {op.inputs[1].shape}"
 
     @classmethod
-    @docstring_format_args([_list_formatter(supported_pad_dtypes)])
+    @docstring_format_args([list_formatter(supported_pad_dtypes)])
     def constraint_pad_type(cls, op):
         "Pad tensor must be of type: {}"
         pad_tensor = op.inputs[1]
diff --git a/ethosu/vela/supported_operators_util.py b/ethosu/vela/supported_operators_util.py
new file mode 100644
index 0000000..24fe72c
--- /dev/null
+++ b/ethosu/vela/supported_operators_util.py
@@ -0,0 +1,31 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Utility functions used in supported operator checking
+
+
+def list_formatter(arg):
+    # Order and join into a string representation
+    return ", ".join(sorted(map(str, arg)))
+
+
+# Custom decorator function to allow formatting docstrings containing "{}"
+def docstring_format_args(args):
+    def docstring(func):
+        func.__doc__ = func.__doc__.format(*args)
+        return func
+
+    return docstring
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 83a3dda..b37bac8 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -15,17 +15,14 @@
 # limitations under the License.
 #
 # Description:
-# Unit tests for graph_optimiser
+# Unit tests for tflite_graph_optimiser
 import numpy as np
 import pytest
 
 from ethosu.vela.data_type import DataType
-from ethosu.vela.graph_optimiser import calc_explicit_padding
-from ethosu.vela.graph_optimiser import convert_batched_fc_shape
-from ethosu.vela.graph_optimiser import optimise_graph_a
-from ethosu.vela.graph_optimiser import replace_pad_by_hw_pad
-from ethosu.vela.graph_optimiser import rewrite_fully_connected_input
+from ethosu.vela.graph_optimiser import optimise_graph
 from ethosu.vela.nn_graph import Graph
+from ethosu.vela.nn_graph import NetworkType
 from ethosu.vela.operation import Op
 from ethosu.vela.operation import Padding
 from ethosu.vela.rewrite_graph import verify_graph_health
@@ -33,6 +30,10 @@
 from ethosu.vela.tensor import Shape4D
 from ethosu.vela.tensor import Tensor
 from ethosu.vela.test import testutil
+from ethosu.vela.tflite_graph_optimiser import calc_explicit_padding
+from ethosu.vela.tflite_graph_optimiser import convert_batched_fc_shape
+from ethosu.vela.tflite_graph_optimiser import replace_pad_by_hw_pad
+from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input
 
 
 def test_convert_batched_fc():
@@ -300,7 +301,7 @@
     pool_op.run_on_npu = True
     nng = testutil.create_graph([pad_op, pool_op])
     arch = testutil.create_arch()
-    nng = optimise_graph_a(nng, arch)
+    nng = optimise_graph(nng, arch, NetworkType.TFLite)
     sg = nng.subgraphs[0]
     all_ops = sg.get_all_ops()
     print("all_ops: ", all_ops)
@@ -382,7 +383,7 @@
     nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
     arch = testutil.create_arch()
     assert verify_graph_health(nng)
-    nng = optimise_graph_a(nng, arch)
+    nng = optimise_graph(nng, arch, NetworkType.TFLite)
     assert verify_graph_health(nng)
 
     # Test2 reshape1 with different quantisation, this Reshape op is expected to remain
@@ -393,5 +394,5 @@
     quant_zp32.zero_point = 32
     reshape1_op.ofm.quantization = quant_zp32
     assert verify_graph_health(nng)
-    nng = optimise_graph_a(nng, arch)
+    nng = optimise_graph(nng, arch, NetworkType.TFLite)
     assert verify_graph_health(nng)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
new file mode 100644
index 0000000..3d9eeb8
--- /dev/null
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -0,0 +1,1633 @@
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
+# to do the traversal of the graph.
+import math
+import uuid
+from typing import Tuple
+
+import numpy as np
+
+from . import fp_math
+from . import lut
+from . import rewrite_graph
+from . import scaling
+from .api import NpuRoundingMode
+from .data_type import DataType
+from .debug_database import DebugDatabase
+from .errors import UnsupportedFeatureError
+from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .graph_optimiser_util import needed_total_padding
+from .graph_optimiser_util import set_ifm_ofm_op_shapes
+from .graph_optimiser_util import set_tensor_equivalence
+from .numeric_util import clamp_sigmoid
+from .numeric_util import full_shape
+from .numeric_util import round_away_zero
+from .operation import create_activation_function
+from .operation import NpuBlockType
+from .operation import Op
+from .operation import Operation
+from .operation import Padding
+from .operation_util import create_avgpool_nop
+from .operation_util import get_pad_values_from_input
+from .shape4d import Shape4D
+from .softmax import SoftMax
+from .tensor import check_quantized_tens_scaling_equal
+from .tensor import create_const_tensor
+from .tensor import create_equivalence_id
+from .tensor import QuantizationParameters
+from .tensor import Tensor
+from .tensor import TensorPurpose
+from .tflite_mapping import optype_to_builtintype
+
+passthrough_nodes = (Op.Identity,)
+
+
+def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
+    """Creates an average pool for the given concat op/input feature map"""
+    ofm = concat_op.ofm
+    avgpool_op = create_avgpool_nop(name)
+    avgpool_op.inputs = [ifm]
+    avgpool_op.outputs = [ofm]
+
+    avgpool_op.write_offset = write_offset
+    avgpool_op.write_shape = ifm_shape
+    ofm.ops.append(avgpool_op)
+    DebugDatabase.add_optimised(concat_op, avgpool_op)
+    avgpool_op.ifm_shapes.append(ifm_shape)
+    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
+    avgpool_op.memory_function = Op.ConcatSliceWrite
+    return avgpool_op
+
+
+def remove_passthrough_tensor(tens, arch, nng):
+    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
+        assert len(tens.ops[0].inputs) == 1
+        tens = tens.ops[0].inputs[0]
+    return tens
+
+
+def rewrite_concat_ops(op, arch):
+    if not op.run_on_npu or not op.type.is_concat_op():
+        return
+
+    axis_4D = 0
+    ofm = op.ofm
+    ofm.ops = []
+    offset = 0
+
+    unfuse_activation_function(op)
+
+    if op.type == Op.Pack:
+        # Pack is also referred to as Stack
+        axis = int(op.attrs["axis"])
+        if axis < 0:  # Convert to positive axis
+            axis = len(op.inputs[0].shape) + 1 + axis
+
+        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
+
+        axis_4D = axis + (4 - len(desired_shape))
+
+        for idx, inp in enumerate(op.inputs):
+            op.ifm_shapes[idx] = Shape4D(desired_shape)
+        op.type = Op.PackReshaped
+
+    inputs, axis = op.get_concat_inputs_axis()
+    for idx, inp in enumerate(inputs):
+        if op.type != Op.PackReshaped:
+            op.ifm_shapes[idx] = Shape4D(inp.shape)
+            if axis >= 0:
+                axis_4D = axis + (4 - len(inp.shape))
+            else:
+                axis_4D = axis
+        write_offset = [0, 0, 0, 0]
+        write_offset[axis_4D] = offset
+        concat_end = offset + op.ifm_shapes[idx][axis_4D]
+        create_avg_pool_for_concat(
+            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
+        )
+        offset = concat_end
+    assert ofm.shape[axis] == offset
+
+    return op
+
+
+def rewrite_split_ops(tens, arch, nng):
+
+    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
+        split_op = tens.ops[0]
+
+        # Not supported so leave it and run on CPU
+        if not split_op.run_on_npu:
+            return tens
+
+        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
+
+        tens.ops = []
+        new_op = Operation(Op.SplitSliceRead, split_op.name)
+        new_op.inputs = [inp]
+        ofm_shape_idx = 0
+        read_shape = offset_end
+
+        # For Split the offset cannot be extracted from the tensor so it has to
+        # be calculated from the index of the output tensor
+        if axis is not None:
+            # Get the start and end of the split
+            offset_start = [0] * 4
+            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
+            for idx, out in enumerate(outputs):
+                if axis_4D_list is not None:
+                    axis_4D = axis_4D_list[idx]
+                else:
+                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
+                    if axis >= 0:
+                        axis_4D = axis + (4 - len(out.shape))
+                    else:
+                        axis_4D = axis
+
+                if out == tens:
+                    ofm_shape_idx = idx
+                    read_shape = split_op.ofm_shapes[idx]
+                    break
+
+                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
+
+        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
+        new_op.read_shapes[0] = read_shape
+        new_op.run_on_npu = True
+        new_op.set_output_tensor(tens)
+        new_op.ifm_shapes.append(Shape4D(inp.shape))
+        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
+        DebugDatabase.add_optimised(split_op, new_op)
+
+    return tens
+
+
+def remove_SplitSliceRead(op, arch):
+
+    if op.type == Op.SplitSliceRead:
+        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool need to be inserted
+        if (
+            len(op.ofm.consumer_list) == 1
+            and op.ofm.consumer_list[0] is not None
+            and op.ofm.consumer_list[0].run_on_npu
+            and op.ofm.consumer_list[0].type != Op.Reshape
+            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+        ):
+            # SplitSliceRead can be performed by tensor consumer
+            cons_op = op.ofm.consumer_list[0]
+            if cons_op.ifm == op.ofm:
+                cons_op.read_offsets[0] = op.read_offsets[0]
+                cons_op.read_shapes[0] = op.read_shapes[0]
+                cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
+                cons_op.ifm_shapes[0] = op.ifm_shapes[0]
+            elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
+                cons_op.read_offsets[1] = op.read_offsets[0]
+                cons_op.read_shapes[1] = op.read_shapes[0]
+                cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
+                cons_op.ifm_shapes[1] = op.ifm_shapes[0]
+
+            if "skirt" in cons_op.attrs:
+                assert cons_op.attrs["explicit_padding"] == cons_op.attrs["skirt"]
+                cons_op.attrs["skirt"] = None
+                cons_op.attrs["force_padding"] = True
+            op.ofm.consumer_list.remove(cons_op)
+            op.ofm.ops = []
+            op.ifm.consumer_list.remove(op)
+        else:
+            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
+            avgpool_op.add_input_tensor(op.ifm)
+            avgpool_op.outputs = [op.ofm]
+            op.ofm.ops.remove(op)
+            op.ofm.ops.append(avgpool_op)
+            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
+            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
+            avgpool_op.read_offsets[0] = op.read_offsets[0]
+            avgpool_op.read_shapes[0] = op.read_shapes[0]
+
+            op.ifm.consumer_list.remove(op)
+            DebugDatabase.add_optimised(op, avgpool_op)
+
+
+def insert_copy_op_after_tens(tens):
+    tens_cons_list_copy = tens.consumer_list.copy()
+
+    # Create a avg_pool nop op with ifm as input
+    copy_tens = tens.clone()
+    copy_op = create_avgpool_nop(tens.name + "_avgpool")
+    copy_op.add_input_tensor(tens)
+    copy_op.set_output_tensor(copy_tens)
+    copy_op.set_ifm_ofm_shapes()
+    copy_op.run_on_npu = True
+
+    # Set copy_ifm consumers
+    for tens_cons in tens_cons_list_copy:
+        if tens_cons is not None:
+            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
+                if cons_inp == tens:
+                    tens_cons.set_input_tensor(copy_tens, ifm_idx)
+
+    DebugDatabase.add_optimised(tens.ops[0], copy_op)
+
+
+def fix_sg_input_output(op, arch, nng):
+    if not op.run_on_npu or op.type != Op.Reshape:
+        return op
+
+    # For the Reshape operators we want to remove, tensors are removed.
+    # But in order to to do this, they cannot be outputs of the sg,
+    # this need to be fixed prior to the removal.
+    # Solution is to add a avgpool NOP, to maintain the original tensor.
+    # This is also valid when reshape ifm/ofm is produced respectively
+    # consumed by CPU
+
+    # Check if operator ifm/ofm are sg ifm/ofm
+    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
+    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+    # Check if ifm/ofm is produced repectivly consumed by CPU
+    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
+    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+        # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape
+        insert_copy_op_after_tens(op.ifm)
+
+    return op
+
+
+def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
+    """
+    Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
+    that provides equivalent results.
+    """
+    total_padding = needed_total_padding(input_size, stride, filter_size)
+    # The top/left padding can be taken as is from the PAD
+    output_pad_before = pad_before
+    # The bottom/right padding might need downward adjustment depending on stride/input size
+    output_pad_after = pad_after
+    while output_pad_after > 0 and output_pad_after % stride != (total_padding - pad_before) % stride:
+        output_pad_after -= 1
+    return output_pad_before, output_pad_after
+
+
+def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
+    k_w, k_h = kernel.dilated_wh()
+    s_x, s_y = kernel.stride
+    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
+    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
+    if padding_type == Padding.SAME:
+        left_pad = (xpad + 0) // 2
+        right_pad = (xpad + 1) // 2
+        top_pad = (ypad + 0) // 2
+        bottom_pad = (ypad + 1) // 2
+    elif padding_type == Padding.VALID:
+        left_pad = 0
+        right_pad = 0
+        top_pad = 0
+        bottom_pad = 0
+    elif padding_type == Padding.EXPLICIT:
+        # Padding is specified in a PAD operator which has been bypassed.
+        top, left, bottom, right = explicit_padding
+        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
+        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
+    else:
+        raise UnsupportedFeatureError(f"Unknown padding")
+    padding = (top_pad, left_pad, bottom_pad, right_pad)
+    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
+    return padding, skirt
+
+
+def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
+    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
+    if padding_type == Padding.SAME:
+        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
+        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
+        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
+        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
+        left_pad = max(kernel_width - 1 - right_pad, 0)
+        top_pad = max(kernel_height - 1 - bottom_pad, 0)
+    elif padding_type == Padding.VALID:
+        right_pad = max(kernel_width - 2, 0)
+        bottom_pad = max(kernel_height - 2, 0)
+        left_pad = kernel_width - 1
+        top_pad = kernel_height - 1
+    else:
+        raise UnsupportedFeatureError(f"Unknown padding")
+    padding = (top_pad, left_pad, bottom_pad, right_pad)
+    skirt = padding
+    return padding, skirt
+
+
+def fixup_conv2d_backprop(op, arch, nng):
+    if op.type == Op.Conv2DBackpropInput:
+        # flip the inputs
+        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
+        op.type = Op.Conv2DBackpropInputSwitchedBias
+        op.ifm.resampling_mode = resampling_mode.TRANSPOSE
+
+        # Update strides
+        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
+
+    return op
+
+
+# Convert the op to an elementwise add
+def convert_resizebilinear_1x1_to_add(op):
+    op.type = Op.Add
+    op.name = op.name + "_add"
+    op.attrs["resizebilinear"] = True
+    # Create an input tensor filled with zeros
+    shape = op.ofm_shapes[0].as_list()
+    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
+    tens.values = np.zeros(shape)
+    tens.quant_values = np.zeros(shape, np.uint8)
+    tens.quantization = QuantizationParameters(0.0, 255.0)
+    tens.quantization.scale_f32 = 1.0
+    tens.quantization.zero_point = 0
+    tens.consumer_list = [op]
+    tens_op = op.inputs[1].ops[0]
+    tens_op.set_output_tensor(tens)
+    # Set the add inputs
+    op.inputs[1] = op.inputs[0]
+    op.inputs[0] = tens
+    op.set_ifm_ofm_shapes()
+
+    return op
+
+
+# Convert ResizeBilinear to a number of 2x2 pool ops
+def convert_resizebilinear_to_2x2_pool(op):
+    count = 0
+    pre_op = op
+    outputs = op.outputs
+
+    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
+    if op.attrs["align_corners"]:
+        shape_modifier = 1
+        op.attrs["padding"] = Padding.VALID
+    else:
+        shape_modifier = 0
+        op.attrs["padding"] = Padding.SAME
+    op.inputs[0].resampling_mode = resampling_mode.NEAREST
+
+    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
+    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
+    if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
+        return op
+
+    while (upscaled_shape < out_shape).all():
+        if count == 0:
+            scaled_op = pre_op
+        else:
+            scaled_op = op.clone("_{}".format(count))
+            scaled_op.inputs[0] = pre_op.outputs[0]
+
+        upscaled_shape = upscaled_shape * 2 - shape_modifier
+
+        if (upscaled_shape == out_shape).all():
+            scaled_op.outputs = outputs
+            scaled_op.outputs[0].ops = [scaled_op]
+        else:
+            shape = op.ofm_shapes[0].as_list()
+            shape[1:3] = upscaled_shape
+            out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
+            out_tens.quantization = op.outputs[0].quantization.clone()
+            out_tens.quantization.quant_min = np.iinfo(np.int16).min
+            out_tens.quantization.quant_max = np.iinfo(np.int16).max
+            scaled_op.set_output_tensor(out_tens)
+            pre_op = scaled_op
+            count += 1
+
+        # Setup the scale value
+        if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
+            scaled_op.rescale = 128
+        elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
+            scaled_op.rescale = 1 / 128
+        else:
+            scaled_op.rescale = None
+        scaled_op.set_ifm_ofm_shapes()
+
+    return op
+
+
+def fixup_resizebilinear(op, arch, nng):
+    if op.type == Op.ResizeBilinear and op.run_on_npu:
+        if op.ifm_shapes[0] == op.ofm_shapes[0]:
+            # Bypass nop resizebilinear
+            op.inputs = op.inputs[:1]
+            op.type = Op.Identity
+        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
+            convert_resizebilinear_1x1_to_add(op)
+        else:
+            convert_resizebilinear_to_2x2_pool(op)
+
+    return op
+
+
+def convert_nop_split_to_identity(op, arch, nng):
+    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
+        # the list comprehension should return a list with a single tensor
+        # if it shouldn't, remove_passthrough_tensor will fail appropriately
+        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
+        op.type = Op.Identity
+    return op
+
+
+def rewrite_fully_connected_input(op, arch, nng):
+    if op.type == Op.FullyConnected:
+        n_in_elems = op.weights.shape[-2]
+        elms = op.ifm.elements()
+        batch_size = elms // n_in_elems
+        assert batch_size * n_in_elems == elms
+
+        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
+    return op
+
+
+def convert_batched_fc_shape(op, arch, nng):
+    if op.type == Op.FullyConnected:
+        # Check if the first dimension indicates batching
+        if op.ifm_shapes[0].batch > 1:
+            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+            n = op.ifm_shapes[0].batch
+            h, w = batching_split.get(n, (1, n))
+            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
+
+            # Reshape Weights to be 4D. IO becomes HWIO
+            weight_tensor = op.inputs[1]
+            weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
+            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+
+            n = op.ofm_shapes[0].batch
+            h, w = batching_split.get(n, (1, n))
+            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
+    return op
+
+
+def unfuse_activation_function(op):
+    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
+        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
+        op.activation = None
+        out_tens = op.outputs[0]
+        intermediate_tens = out_tens.clone("_act_intermediate")
+        act_op.set_output_tensor(out_tens)
+        act_op.add_input_tensor(intermediate_tens)
+        op.set_output_tensor(intermediate_tens)
+        act_op.set_ifm_ofm_shapes()
+
+
+def rewrite_stridedslice_output(op, arch, nng):
+    if not op.run_on_npu or op.type != Op.StridedSlice:
+        return op
+
+    new_axis_mask = op.attrs["new_axis_mask"]
+    shrink_axis_mask = op.attrs["shrink_axis_mask"]
+
+    if shrink_axis_mask == 0 and new_axis_mask == 0:
+        return op
+
+    axis_4D = [0] * len(op.outputs)
+    for idx, out_tens in enumerate(op.outputs):
+        output_shape = list(out_tens.shape)
+
+        if shrink_axis_mask != 0:
+            n = 0
+            axis = 0
+            while shrink_axis_mask:
+                prev_mask = shrink_axis_mask
+                n += 1
+                shrink_axis_mask &= shrink_axis_mask - 1
+                axis = int(math.log2(prev_mask - shrink_axis_mask))
+                output_shape = output_shape[:axis] + [1] + output_shape[axis:]
+
+            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
+            op.attrs["shrink_axis_mask"] = 0
+            if axis >= 0:
+                axis_4D[idx] = axis + (4 - len(output_shape))
+            else:
+                axis_4D[idx] = axis
+            op.ofm_shapes[idx] = Shape4D(output_shape)
+
+        elif new_axis_mask != 0:
+            n = 0
+            axis = 0
+            while new_axis_mask:
+                prev_mask = new_axis_mask
+                n += 1
+                new_axis_mask &= new_axis_mask - 1
+                axis = int(math.log2(prev_mask - new_axis_mask))
+                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
+                new_axis_mask >>= 1
+
+            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
+            op.attrs["new_axis_mask"] = 0
+            if axis >= 0:
+                axis_4D[idx] = axis + (4 - len(output_shape))
+            else:
+                axis_4D[idx] = axis
+            op.ofm_shapes[idx] = Shape4D(output_shape)
+
+    op.attrs["split_axis_4D"] = axis_4D
+    return op
+
+
+def rewrite_unpack_output(op, arch, nng):
+    tens = op.outputs[0]
+    if op.run_on_npu and op.type == Op.Unpack:
+        # Unpack is also referred to as Unstack
+        axis = int(op.attrs["axis"])
+        if axis < 0:  # Convert to positive axis
+            axis = len(op.inputs[0].shape) + 1 + axis
+        op.type = Op.UnpackReshaped
+        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
+
+        axis_4D = axis + (4 - len(desired_output_shape))
+        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
+
+        for idx, out_tens in enumerate(op.outputs):
+            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
+    return op
+
+
+def add_padding_fields(op, arch, nng):
+    if op.run_on_npu:
+        if "padding" in op.attrs:
+            input_shape = op.ifm_shapes[0]
+            output_shape = op.ofm_shapes[0]
+            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
+                kernel_size = op.inputs[1].shape[:2]
+            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
+                kernel_size = op.attrs["ksize"][1:3]
+            else:
+                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
+
+            if op.type == Op.Conv2DBackpropInputSwitchedBias:
+                upscaling_factor = output_shape.height // input_shape.height
+                padding, skirt = calc_upscaled_padding_and_skirt(
+                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
+                )
+            else:
+                padding, skirt = calc_padding_and_skirt(
+                    op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
+                )
+
+            op.attrs["explicit_padding"] = padding
+            op.attrs["skirt"] = skirt
+
+    return op
+
+
+def convert_depthwise_to_conv(op, arch, nng):
+    # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
+    # the ofm depth equals the depth multipler.
+    # If those conditions are true, then we can perform a simple
+    # switch of the operator type (and weight order)
+
+    if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
+        ifm_shape = op.ifm_shapes[0]
+        weight_tensor = op.inputs[1]
+        ofm_shape = op.ofm_shapes[0]
+        if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
+            # Change op type to Conv2d
+            op.type = Op.Conv2DBias
+            del op.attrs["channel_multiplier"]
+            del op.attrs["depth_multiplier"]
+
+            weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
+            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+        else:
+            raise UnsupportedFeatureError(
+                f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
+                f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
+            )
+        DebugDatabase.add_optimised(op, op)
+    return op
+
+
+def reorder_depthwise_weights(op, arch, nng):
+    if op.type.is_depthwise_conv2d_op():
+        weight_tensor = op.inputs[1]
+        weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
+        weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+        weight_tensor.weight_transpose_depthwise = True
+
+    return op
+
+
+def optimise_strided_conv(op, arch, nng):
+    stride_x, stride_y = op.get_kernel_stride()
+    ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+
+    if (
+        op.type == Op.Conv2DBias
+        and op.op_index == 0
+        and stride_x == 2
+        and op.ifm_shapes[0].depth <= 4
+        and op.ifm_shapes[0].width % 2 == 0
+        and weight_tensor is not None
+        and weight_tensor.shape[1] >= 2
+    ):
+        ifm_shape = op.ifm_shapes[0]
+        # IFM
+        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
+
+        # Weights
+        weight_shape = weight_tensor.shape
+        if weight_shape[1] % 2 != 0:
+            weight_shape[1] = weight_shape[1] + 1
+            padded_array = np.zeros(weight_shape)
+            for i in range(weight_shape[0]):
+                padded_array[i] = np.vstack(
+                    [
+                        weight_tensor.quant_values[i],
+                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
+                    ]
+                )
+            weight_tensor.quant_values = padded_array
+        weight_shape[1] //= 2
+        weight_shape[2] *= 2
+        weight_tensor.quant_values = np.reshape(weight_tensor.quant_values, weight_shape)
+        weight_tensor.set_all_shapes(weight_shape)
+        # If multiple copies of the weights are used, we could avoid
+        # them having the same address by changing the value_id
+        weight_tensor.value_id = uuid.uuid4()
+
+        # Strides
+        stride_x = 1
+        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
+
+    return op
+
+
+def convert_conv_to_fc(op, arch, nng):
+    # Conv 1x1 can be equivalent to Fully Connected.
+    # By representing certain convs as fully connected layers, Vela can better determine wether or not to use
+    # caching/double buffering for the weights.
+    # (Weights dont need to be reloaded for convs when IFM H and W are 1)
+    if op.type == Op.Conv2DBias:
+        h = op.ifm_shapes[0].height
+        w = op.ifm_shapes[0].width
+        kh, kw, _, _ = op.inputs[1].shape
+        if h == 1 and w == 1 and kh == 1 and kw == 1:
+            # Overwrite this op as a Fully Connected Op
+            op.name += "_fc"
+            op.type = Op.FullyConnected
+            op.attrs = {
+                "weights_format": 0,
+            }
+            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
+            weight_tensor = op.inputs[1]
+            weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
+            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+
+            DebugDatabase.add_optimised(op, op)
+    return op
+
+
+def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
+    if op.run_on_npu and op.type.is_relu_op():
+        ifm = op.inputs[0]
+        ofm = op.outputs[0]
+        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
+        # and requires its own to be inserted
+        if not check_quantized_tens_scaling_equal(ifm, ofm):
+            # Override this op with its own primary op (avgpool)
+            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
+            # And fuse the original activation function to it
+            relu_fused_op.activation = create_activation_function(op.type)
+            # Tidy up and assign the ifm and ofm to the new op
+            ifm.consumer_list.remove(op)
+
+            relu_fused_op.add_input_tensor(ifm)
+            relu_fused_op.set_output_tensor(ofm)
+            relu_fused_op.set_ifm_ofm_shapes()
+            op = relu_fused_op
+    return op
+
+
+def fixup_elementwise_with_scalars(op, arch, nng):
+    if op.type.is_binary_elementwise_op():
+        ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
+        if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
+            diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
+            if diff > 0:
+                ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
+            elif diff < 0:
+                ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
+        elif ifm_tensor.shape == [] and ifm_tensor.quant_values is None:
+            # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
+            ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
+            ifm_tensor.storage_shape = ifm_tensor.shape
+        elif ifm2_tensor.shape == [] and ifm2_tensor.quant_values is None:
+            # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
+            ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
+            ifm2_tensor.storage_shape = ifm2_tensor.shape
+    return op
+
+
+def convert_softmax(op, arch, nng):
+    if op.type == Op.Softmax and op.run_on_npu:
+        softmax = SoftMax(op)
+        op = softmax.get_graph()
+    return op
+
+
+def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
+    r"""Whenever there is a subgraph with this topology:
+
+       Input    X   For X = -1 or X > 0
+       |   \   /    This subgraph can be replaced with either
+       |    Mul     an Abs (if X = -1) or a LeakyReLU (if X > 0)
+       |   /
+       Max
+    """
+
+    if op.type == Op.Maximum:
+        # finds the Mul input(s) to the Max
+        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
+        if len(muls) == 1:
+            mul = muls[0].ops[0]
+        elif len(muls) == 2:
+            # In the case both inputs are Muls, find the one with the same input as the Max
+            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
+        else:
+            # No Mul inputs
+            return op
+
+        # make sure the Mul doesn't have any other consumers
+        mul_ofm = mul.outputs[0]
+        if len(mul_ofm.consumers()) != 1:
+            return op
+        # make sure the Mul doesn't have a fused activation function
+        if mul.activation:
+            return op
+        ifm, ofm = op.get_ifm_ofm()
+        if ifm is None or ofm is None:
+            return op
+
+        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
+            return op
+        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
+            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
+            return op
+
+        # finds the branched input that goes to both the Max and the Mul
+        shared = set(op.inputs) & set(mul.inputs)
+        if len(shared) == 1:
+            shared_in = shared.pop()
+            # find the constant scalar input to the Mul
+            const_tens = (set(mul.inputs) - {shared_in}).pop()
+            # check that it is a scalar
+            if const_tens.shape != []:
+                return op
+            const = const_tens.ops[0]
+            # check that it is a constant
+            if const.type != Op.Const:
+                return op
+            # Remove the Mul from the shared input's consumers
+            shared_in.consumer_list.remove(mul)
+        else:
+            return op
+
+        val = const.outputs[0].values
+        if val >= 0:
+            new_op = Op.LeakyRelu
+            op.attrs["alpha"] = val
+            # to produce bit exact results, the alpha is not enough;
+            # save additional scaling info in attr "alpha_scale", to be used as input
+            # to the LUT construction
+            alpha_scalar = const_tens.quant_values - const_tens.quantization.zero_point
+            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
+            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
+            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
+            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
+            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
+        elif val == -1:
+            new_op = Op.Abs
+        else:
+            return op
+
+        op.type = new_op
+        op.name = op.name.replace("Maximum", new_op.name)
+        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
+        op.inputs = [shared_in]
+        op.set_ifm_ofm_shapes()
+
+        # Record optimisation in debug database
+        DebugDatabase.add_optimised(op, op)
+
+    return op
+
+
+def convert_hardswish_to_lut(op, arch, nng):
+    if op.type == Op.HardSwish:
+        ifm, ofm = op.get_ifm_ofm()
+        # Generate the LUT
+        ifm_scale = np.double(ifm.quantization.scale_f32)
+        ofm_scale = np.double(ofm.quantization.scale_f32)
+        zp_in = ifm.quantization.zero_point
+        zp_out = ofm.quantization.zero_point
+        ifm_scale_hires = (1 / 128) * ifm_scale
+        relu_multiplier = np.double(3 / 32768)
+        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
+        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
+        # Use 16bit scale
+        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
+        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
+
+        values = []
+        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
+        quantized_min = min(ix)
+        quantized_max = max(ix)
+        for x in ix:
+            input_value = x - zp_in
+            input_value_hires = input_value * 128
+            # Compute the input value on essentially the output scale, not shifted yet
+            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
+            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
+            relu_value = np.int16(input_value_hires)
+            if relu_shift < 31:
+                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
+
+            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
+
+            if relu_shift < 31:
+                relu_value = fp_math.shift_left16(relu_value, 1)
+
+            if relu_shift > 31:
+                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
+
+            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
+            # Now convert that to a 16bit fixedpoint value in [0, 1]
+            relu_value = (relu_value + (1 << 15)) >> 1
+            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
+            shift = 31 - out_shift
+            shift = -shift if shift < 0 else 0
+            # Finally apply the output shift
+            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
+            lut_result = min(quantized_max, max(quantized_min, lut_result))
+            values.append(lut_result)
+        return convert_to_lut(op, values, "hardswish")
+    return op
+
+
+def convert_lrelu_to_mul_max(op, arch):
+    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
+    # (the opposite of convert_mul_max_to_abs_or_lrelu)
+    ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op
+
+    # Add multiplication with alpha
+    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
+    mul_alpha.add_input_tensor(ifm)
+    # Create const tensor containing alpha as scalar
+    alpha = op.attrs["alpha"]
+    quantization = ifm.quantization.clone()
+    quantization.min = 0
+    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
+    quantization.zero_point = 0
+    if np.isinf(1 / np.float32(alpha)):
+        # Handling of alpha near zero
+        quantization.scale_f32 = 1
+        scalar = 0
+    else:
+        quantization.scale_f32 = alpha
+        scalar = alpha
+    alpha_tens = create_const_tensor(
+        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
+    )
+    alpha_tens.quant_values = np.array([1])
+    mul_alpha.add_input_tensor(alpha_tens)
+    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
+    mul_alpha.set_output_tensor(fm_alpha)
+    mul_alpha.set_ifm_ofm_shapes()
+    DebugDatabase.add_optimised(op, mul_alpha)
+
+    if check_quantized_tens_scaling_equal(ifm, ofm):
+        # No identity multiplication is needed
+        fm_id = ifm
+    else:
+        # Add multiplication with identity
+        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
+        mul_identity.add_input_tensor(ifm)
+        # Create const tensor containing identity as scalar
+        quantization = ifm.quantization.clone()
+        quantization.min = 0
+        quantization.max = quantization.quant_max - quantization.quant_min
+        quantization.scale_f32 = 1
+        quantization.zero_point = 0
+        identity_tens = create_const_tensor(
+            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
+        )
+        mul_identity.add_input_tensor(identity_tens)
+        # Make sure that fm_id is allocated to a different address than fm_alpha
+        fm_id = ofm.clone(op.name + "_id", set_unique=True)
+        mul_identity.set_output_tensor(fm_id)
+        mul_identity.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, mul_identity)
+
+    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
+    op.type = Op.Maximum
+    op.name = op.name.replace("LeakyRelu", "Maximum")
+    op.inputs = []
+    ifm.consumer_list.remove(op)
+    op.add_input_tensor(fm_alpha)
+    op.add_input_tensor(fm_id)
+    op.set_ifm_ofm_shapes()
+
+    DebugDatabase.add_optimised(op, op)
+    return op
+
+
+def convert_to_lut(op, lut_values, lut_name):
+    # Rewrite the operation by Add with scalar 0 + LUT activation
+    ifm = op.inputs[0]
+    if ifm is None:
+        return op
+    assert ifm.dtype.size_in_bytes() == 1
+    op.type = Op.Add
+    op.name = op.name + "_lut_" + lut_name
+    # Mark as no-op to enable potential fusing optimizations
+    op.attrs["is_nop"] = True
+    # Create an input tensor containing scalar zero
+    quantization = QuantizationParameters(0.0, 255.0)
+    quantization.scale_f32 = ifm.quantization.scale_f32
+    quantization.zero_point = 0
+    tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
+    op.add_input_tensor(tens)
+    op.ifm_shapes.append(Shape4D(tens.shape))
+
+    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
+    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
+    # should be the same as the IFM
+    op.forced_output_quantization = ifm.quantization
+    lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
+    op.set_activation_lut(lut_tensor)
+    op.set_ifm_ofm_shapes()
+    return op
+
+
+def convert_to_lut8(op, fn, fn_name):
+    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
+    # fn is a function(real) -> real
+    ifm, ofm = op.get_ifm_ofm()
+    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
+        return op
+    # Generate the LUT
+    ifm_scale = np.double(ifm.quantization.scale_f32)
+    ofm_scale = np.double(ofm.quantization.scale_f32)
+    zp_in = ifm.quantization.zero_point
+    zp_out = ofm.quantization.zero_point
+    values = []
+    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
+    quantized_min = min(ix)
+    quantized_max = max(ix)
+    for x in ix:
+        x_real = ifm_scale * (x - zp_in)
+        y_real = fn(x_real)
+        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
+        lut_result = min(quantized_max, max(quantized_min, lut_result))
+        values.append(lut_result)
+    return convert_to_lut(op, values, fn_name)
+
+
+def convert_lrelu_to_lut(op, arch):
+    ifm, ofm = op.get_ifm_ofm()
+    # Generate the LUT
+    alpha = op.attrs["alpha"]
+    ifm_scale = np.double(ifm.quantization.scale_f32)
+    ofm_scale = np.double(ofm.quantization.scale_f32)
+    zp_in = ifm.quantization.zero_point
+    zp_out = ofm.quantization.zero_point
+    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
+    alpha_scalar = 1
+    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
+    if "alpha_scaling" in op.attrs:
+        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
+        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
+    values = []
+    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
+    quantized_min = min(ix)
+    quantized_max = max(ix)
+    for x in ix:
+        if x < zp_in:
+            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
+                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
+            )
+        else:
+            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
+        lut_result = min(quantized_max, max(quantized_min, lut_result))
+        values.append(lut_result)
+    return convert_to_lut(op, values, "lrelu")
+
+
+def convert_lrelu(op, arch, nng):
+    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
+    if op.type != Op.LeakyRelu:
+        return op
+    ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op
+    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
+        # use LUT for int8/uint8
+        return convert_lrelu_to_lut(op, arch)
+    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
+        # use LeakyRelu unmodified for int16 with equal input/output scaling
+        return op
+    return convert_lrelu_to_mul_max(op, arch)
+
+
+def convert_tanh_sigmoid_to_lut(op, arch, nng):
+    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
+    if op.type == Op.Sigmoid:
+        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
+    elif op.type == Op.Tanh:
+        return convert_to_lut8(op, math.tanh, "tanh")
+    return op
+
+
+def remove_reshapes(op, arch):
+    if op.run_on_npu and op.type == Op.Reshape:
+        ofm = op.ofm
+        ifm = op.ifm
+
+        # Check if quantization is the same in the input and output for the reshape ops
+        if not check_quantized_tens_scaling_equal(ifm, ofm):
+            # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
+            # In order to remove this reshape either quantization properties need to be moved to Operator,
+            # or the reshape need to be replace with a NOP.
+            return
+
+        # Check if Reshape ifm/ofm are network ifm/ofm
+        ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+        ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
+        ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
+        # Check if ifm/ofm is produced repectivly consumed by CPU
+        ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+        ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
+        # This case should be handled prior to this function
+        assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
+
+        if ofm_is_sg_ofm or ofm_is_cpu_consumed:
+            # Bypassed by replacing ifm with ofm
+            ofm.ops = []
+            for prev_op in ifm.ops:
+                prev_op.outputs = [ofm]
+                ofm.ops.append(prev_op)
+
+            # All ifm consumers need to use ofm as input
+            for ifm_cons in ifm.consumer_list:
+                for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+                    if cons_ifm == ifm:
+                        ifm_cons.set_input_tensor(ofm, ifm_idx)
+        else:
+            # Bypassed Reshape by replacing ofm with ifm
+            for cons in ofm.consumer_list:
+                for ifm_idx, cons_ifm in enumerate(cons.inputs):
+                    if cons_ifm == ofm:
+                        cons.set_input_tensor(ifm, ifm_idx)
+
+
+def fuse_activation_function_with_prev(op, arch, nng):
+    # if op is a no-op: attempts to move the activation function to the preceding op
+    if not op.attrs.get("is_nop", False) or op.activation is None:
+        return op
+    ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op
+    # finds the input(s) to the operation
+    prev_op = ifm.ops[0]
+    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
+    fuse = (
+        prev_op.run_on_npu
+        and prev_op.type.npu_block_type != NpuBlockType.Default
+        and len(ifm.ops) == 1
+        and len(prev_op.outputs[0].consumers()) == 1
+        and prev_op.activation is None
+    )
+    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
+        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
+        # LUT currently only works correctly for elementwise ops
+        fuse = False
+    if not fuse:
+        return op
+    # Move the fused activation function + corresponding info to prev_op
+    prev_op.activation = op.activation
+    prev_op.forced_output_quantization = op.forced_output_quantization
+    if op.activation_lut is not None:
+        prev_op.set_activation_lut(op.activation_lut)
+    # Bypass op
+    prev_op.set_output_tensor(ofm)
+    DebugDatabase.add_optimised(op, prev_op)
+    return op
+
+
+def _leading_pad_ok(leading_pad, stride, kernel_size):
+    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
+    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
+    max_size = kernel_size // 2
+    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
+
+
+def replace_pad_by_hw_pad(op: Operation, arch, nng):
+    """
+    Tries to completely remove a PAD operator by using hardware padding.
+    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
+    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
+    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
+    if both operations can be run on the NPU.
+    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
+    """
+    if (
+        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
+        and op.run_on_npu
+        and op.attrs["padding"] == Padding.VALID
+    ):
+        pad_op = op.ifm.ops[0]
+        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
+            return op
+        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
+            return op
+        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
+        k = op.kernel
+        k_w, k_h = k.dilated_wh()
+
+        # Check if the PAD operator can be replaced by hardware padding
+        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
+            # Too much padding, it would require hardware padding to actually insert zeros
+            return op
+        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
+            return op
+
+        if op.type.is_avgpool_op():
+            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
+            for pad, k_size in (
+                (left, k_w),
+                (right, k_w),
+                (top, k_h),
+                (bottom, k_h),
+            ):
+                if pad not in (0, k_size // 2):
+                    return op
+            # Average pool is converted to depthwise, because NPU average pool + same padding
+            # has a special implementation that is different from PAD followed by average pool with
+            # valid padding.
+            k_w, k_h = op.kernel.width, op.kernel.height
+            ifm = op.ifm
+            # Remember other inputs
+            other_inputs = op.inputs[1:]
+            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
+            quantization = QuantizationParameters(0.0, 255.0)
+            quantization.scale_f32 = 1.0 / (k_w * k_h)
+            quantization.zero_point = 0
+            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
+            weights = np.full(shape, 1)
+
+            weight_tens = create_const_tensor(
+                op.name + "_weights",
+                shape,
+                op.ifm.dtype,
+                weights,
+                np.uint8,
+                purpose=TensorPurpose.Weights,
+                quantization=quantization,
+            )
+            weight_tens.quant_values = weights
+            op.type = Op.DepthwiseConv2DBias
+            op.inputs = []
+            op.add_input_tensor(ifm)
+            op.add_input_tensor(weight_tens)
+            # Add bias tensor, all biases set to 0
+            op.inputs.append(None)
+            fixup_bias_tensors(op, arch, nng)
+            # Add other inputs
+            op.inputs.extend(other_inputs)
+            op.rounding_mode = NpuRoundingMode.NATURAL
+
+        # Bypass the PAD operator
+        op.set_input_tensor(pad_op.ifm, 0)
+        # Adjust the padding attributes of the convolution operator
+        op.attrs["padding"] = Padding.EXPLICIT
+        op.attrs["explicit_padding"] = (top, left, bottom, right)
+        op.set_ifm_ofm_shapes()
+    return op
+
+
+def convert_pad(op: Operation, arch, nng):
+    """
+    Rewrites PAD operator to an average pool that copies the IFM to the OFM
+    + up to 4 average pool operators that fill the OFM with zeros at the borders.
+    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
+    """
+    if op.type != Op.Pad or not op.run_on_npu:
+        return op
+    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
+
+    ifm = op.ifm
+    assert ifm is not None
+    ifm_shape = Shape4D(ifm.shape)
+    ofm = op.ofm
+    assert ofm is not None
+    ofm.ops = []
+    ofm_shape = op.ofm_shapes[0]
+
+    # Average pool op that copies IFM to the right place inside the OFM
+    shp0 = Shape4D(0, 0, 0, 0)
+    shp_top = shp0.with_height(top)
+    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
+    avgpool_op.activation = op.activation
+    quant = ofm.quantization
+    pad_value = quant.zero_point
+    # Add operations that fill the borders of the OFM
+    if top > 0:
+        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
+    if bottom > 0:
+        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_bottom",
+            shape.as_list(),
+            ofm.dtype,
+            shape.elements() * [pad_value],
+            np.uint8,
+            quantization=quant,
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_avg_pool_for_concat(
+            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
+        )
+    if left > 0:
+        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
+    if right > 0:
+        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_avg_pool_for_concat(
+            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
+        )
+
+    op.type = Op.ConcatTFLite
+    return avgpool_op
+
+
+def add_attrs_to_resizebilinear(op, arch, nng):
+    if op.type == Op.ResizeBilinear and op.run_on_npu:
+        input_tensor = op.inputs[0]
+        input_shape = op.ifm_shapes[0]
+        upscaled_height = input_shape.height * 2
+        upscaled_width = input_shape.width * 2
+        out_shape = op.ofm_shapes[0]
+        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
+            # this means the output is supposed to be a x2 upscale,
+            # so we need to do SAME padding
+            op.attrs["padding"] = Padding.SAME
+        elif (
+            op.attrs["align_corners"]
+            and out_shape.height == (upscaled_height - 1)
+            and out_shape.width == (upscaled_width - 1)
+        ):
+            # here we can just run the avg pool without padding and
+            # produce a (M * 2 - 1, N * 2 - 1) sized output
+            op.attrs["padding"] = Padding.VALID
+        else:
+            return op
+        input_tensor.resampling_mode = resampling_mode.NEAREST
+        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
+    return op
+
+
+def fixup_bias_tensors(op, arch, nng):
+    if op.type.needs_bias() and op.bias is None:
+        # Op has no bias, add bias tensor filled with zeros
+        nr_biases = op.inputs[1].shape[-1]
+        bias_values = [0] * nr_biases
+        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
+        bias_tensor.quant_values = bias_tensor.values
+        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
+
+    return op
+
+
+def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
+    if op.type == Op.Mean and op.run_on_npu:
+        keep_dims = op.attrs.get("keep_dims", False)
+        inp, axis = op.inputs
+        shape = inp.shape
+        dims = len(shape)
+
+        # Height and width axes have different index depending on dimensions
+        if axis.shape == [] or axis.shape[0] == 1:  # single axis
+            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
+            if dims in (2, 3):
+                if axis == 0:
+                    h, w = shape[axis], 1
+                else:
+                    h, w = 1, shape[axis]
+            else:
+                if axis == 1:
+                    h, w = shape[axis], 1
+                else:
+                    h, w = 1, shape[axis]
+        else:  # multiple axes
+            axis = sorted(axis.values)
+            h, w = [shape[i] for i in axis]
+
+        # Set necessary depthwise attributes
+        op.attrs.update(
+            {
+                "padding": Padding.VALID,
+                "stride_h": 1,
+                "stride_w": 1,
+                "strides": (1, 1, 1, 1),
+                "depth_multiplier": 1,
+                "channel_multiplier": 1,
+                "dilation_h_factor": 1,
+                "dilation_w_factor": 1,
+                "dilation": (1, 1, 1, 1),
+            }
+        )
+        # Change op type
+        op.type = Op.DepthwiseConv2DBias
+        # Set IFM/OFM shapes after changing op type
+        op.set_ifm_ofm_shapes()
+
+        weight_scale, bias = 1, None
+        ofmq, ifmq = op.ofm.quantization, inp.quantization
+        # Set rounding mode, scaling and zero point based on which reference implementation to match
+        if len(shape) == 4 and axis == [1, 2] and keep_dims:
+            if inp.dtype == DataType.uint8:
+                # This attribute means a different scaling calculation is used in order to match reference
+                op.low_precision_scaling = True
+                weight_scale = h * w
+                # Set zero points to 0 as they will be adjusted for with bias term
+                foq = ofmq.clone()
+                foq.zero_point = 0
+                fiq = ifmq.clone()
+                fiq.zero_point = 0
+                op.forced_input_quantization = fiq
+                bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
+                # If the bias term is outside uint8 range, we need an Add op to apply it.
+                if bias_term < 0 or bias_term > 255:
+                    intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
+                    # Bias term has higher bitness (i32) than input/output (u8).
+                    # 16 bits is enough since the bias is added/subtracted from a u8 value,
+                    # the bias can only effectively assume values in the range [-255, 255].
+                    intermediate.dtype = DataType.int16
+                    intermediate.quantization.zero_point = 0
+                    add_op = Operation(Op.Add, op.name + "_bias")
+                    add_op.forced_output_quantization = foq
+                    add_op.add_input_tensor(intermediate)
+                    quant = QuantizationParameters()
+                    quant.zero_point = 0
+                    bias_term_tens = create_const_tensor(
+                        op.name + "_bias",
+                        [1, 1, 1, 1],
+                        DataType.int16,
+                        [bias_term],
+                        np.int16,
+                        quantization=quant,
+                        quant_value_dtype=np.int16,
+                    )
+                    add_op.add_input_tensor(bias_term_tens)
+                    add_op.set_output_tensor(op.ofm)
+                    add_op.set_ifm_ofm_shapes()
+                    add_op.activation = op.activation
+                    op.activation = None
+                    op.set_output_tensor(intermediate)
+                    op.set_ifm_ofm_shapes()
+                # If not, we can just do it with the OFM zero point.
+                else:
+                    foq.zero_point = bias_term
+                    op.forced_output_quantization = foq
+            else:
+                assert inp.dtype == DataType.int8
+                # Use a depthwise to calculate the sum,
+                # followed by a multiplication with 1/N to get the MEAN
+                weight_scale = 1
+                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
+                intermediate.dtype = DataType.int16
+                mul_op = Operation(Op.Mul, op.name + "_mul")
+                mul_op.add_input_tensor(intermediate)
+                # Create scalar containing 1/N
+                quant = QuantizationParameters()
+                quant.zero_point = 0
+                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
+                # while rounding mode NATURAL would round this to -1.
+                # This can only occur if N is even, and can be emulated by
+                # multiplying with a number that is slightly smaller than 1/N.
+                # It must be so small that other roundings are not affected;
+                # the calculated value is based on worst case,
+                # which is sum 256 * N (the maximum sum that can occur with int8)
+                n = int(h * w)
+                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
+                quant.scale_f32 = 1 / (n - eps)
+                scalar = create_const_tensor(
+                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
+                )
+                mul_op.add_input_tensor(scalar)
+                mul_op.set_output_tensor(op.ofm)
+                mul_op.set_ifm_ofm_shapes()
+                mul_op.rounding_mode = NpuRoundingMode.NATURAL
+                mul_op.activation = op.activation
+                op.activation = None
+                op.set_output_tensor(intermediate)
+                op.set_ifm_ofm_shapes()
+        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
+            # Here we can just use a simple AvgPool with truncating rounding,
+            # as we're emulating simple integer division.
+            op.rounding_mode = NpuRoundingMode.TRUNCATE
+            op.type = Op.AvgPool
+            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
+        else:
+            op.rounding_mode = NpuRoundingMode.NATURAL
+            weight_scale = 1 / (h * w)
+            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
+            bias = -ifmq.zero_point * h * w
+            fiq = ifmq.clone()
+            fiq.zero_point = 0
+            op.forced_input_quantization = fiq
+
+        # Change dimensions to 4
+        if dims < 4:
+            shape = [1] + shape
+            if dims == 2:
+                shape += [1]
+
+        # If height is greater than max kernel height, reshape to from HxW to 1x(HxW)
+        if h > 64:
+            shape = [shape[0], 1, h * w, shape[3]]
+            op.ifm_shapes[0] = Shape4D(shape)
+            if h > 256 and op.type == Op.AvgPool:
+                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
+
+        # If the AvgPool version is used, we don't need to do anything else
+        if op.type == Op.AvgPool:
+            return op
+
+        # Make unit weight tensor quantization
+        weight_quant = ifmq.clone()
+        weight_quant.min = 0
+        weight_quant.max = 255
+        weight_quant.scale_f32 = weight_scale
+        weight_quant.zero_point = 0
+
+        # Set weight shape to [H,W,C,B]
+        weight_shape = shape[1:4] + [shape[0]]
+        # Add unit weight tensor
+        op.set_input_tensor(
+            create_const_tensor(
+                "weights",
+                weight_shape,
+                inp.dtype,
+                np.ones(weight_shape),
+                value_dtype=np.uint8,
+                quantization=weight_quant,
+            ),
+            1,
+        )
+        op.weights.quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
+
+        # Add None bias tensor
+        op.inputs.append(None)
+        # Add bias tensor
+        if bias:
+            bias_shape = [shape[-1]]
+            op.set_input_tensor(
+                create_const_tensor(
+                    "bias",
+                    bias_shape,
+                    inp.dtype,
+                    np.ones(bias_shape) * bias,
+                    value_dtype=np.int32,
+                    quant_value_dtype=np.int32,
+                    quantization=None,
+                ),
+                2,
+            )
+
+    return op
+
+
+def supported_operator_check(op, arch, nng):
+    op.run_on_npu = arch.supported_operators.is_operator_supported(op)
+    return op
+
+
+def tflite_optimise_graph(nng, arch):
+    # Pre-processing step
+    pre_process_list = [
+        supported_operator_check,
+        set_ifm_ofm_op_shapes,
+    ]
+
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+        )
+
+    # Handle Concat Ops
+    for idx, sg in enumerate(nng.subgraphs):
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
+        sg.refresh_after_modification()
+
+    # Handle Split Ops
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng,
+            sg,
+            arch,
+            [],
+            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
+            rewrite_unsupported=False,
+        )
+
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
+        )
+
+    # Handle sg input output
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
+        )
+
+    # Removal of reshapes
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+        sg.refresh_after_modification()
+
+    # Rewrite of operators
+    op_rewrite_list = [
+        set_tensor_equivalence,
+        convert_mean_to_depthwise_conv_or_avgpool,
+        convert_depthwise_to_conv,
+        convert_conv_to_fc,
+        convert_softmax,
+        optimise_strided_conv,
+        convert_hardswish_to_lut,
+        rewrite_fully_connected_input,
+        convert_batched_fc_shape,
+        fixup_conv2d_backprop,
+        fixup_relus_with_differing_ifm_ofm_scaling,
+        fixup_elementwise_with_scalars,
+        reorder_depthwise_weights,
+        fixup_resizebilinear,
+        fixup_bias_tensors,
+        convert_mul_max_to_abs_or_lrelu,
+        convert_lrelu,
+        convert_tanh_sigmoid_to_lut,
+        replace_pad_by_hw_pad,
+    ]
+
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
+        )
+
+    for idx, sg in enumerate(nng.subgraphs):
+        # remove passthrough tensors and attempt further optimizations
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng,
+            sg,
+            arch,
+            [remove_passthrough_tensor],
+            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
+        )
+
+    # Removal of SplitSliceRead, need to be done after optimisation has been performed,
+    # since ifm/ofm_shapes are of importance to this function
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
+        sg.refresh_after_modification()
+
+    return nng
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index b47177f..1a45a5e 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -27,6 +27,9 @@
 from .operation import create_activation_function
 from .operation import Op
 from .operation import Operation
+from .reader_util import clone_and_reshape_tensor
+from .reader_util import decode_str
+from .reader_util import fixup_tensors
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 from .tflite.BuiltinOperator import BuiltinOperator
@@ -37,29 +40,6 @@
 from .tflite_mapping import datatype_map_numpy
 
 
-def decode_str(s):
-    if s is None:
-        return ""
-    return s.decode("utf-8")
-
-
-def clone_and_reshape_tensor(src_tens, reorder, set_unique):
-    tens = src_tens.clone("_reshape", set_unique)
-    tens.shape = [src_tens.shape[idx] for idx in reorder]
-    tens.bandwidth_shape = tens.shape
-    tens.storage_shape = tens.shape
-
-    if tens.values is not None:
-        tens.values = tens.values.transpose(reorder)
-
-    if tens.quant_values is not None:
-        tens.quant_values = tens.quant_values.transpose(reorder)
-
-    op = Operation(Op.Const, tens.name)
-    op.set_output_tensor(tens)
-    return tens
-
-
 class TFLiteSubgraph:
     def __init__(self, graph, subgraph):
         self.graph = graph
@@ -74,19 +54,7 @@
 
         self.outputs = self.get_tensors_from_indices_remove_duplicates(subgraph.OutputsAsNumpy(), "output")
         self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input")
-
-        # Fix up tensors without operations. Generate either Placeholder or Constant ops
-        for tens in self.inputs:
-            if tens.ops != []:
-                tens.error("This subgraph input tensor has unexpected driving operators.")
-
-            op = Operation(Op.Placeholder, tens.name)
-            op.set_output_tensor(tens)
-
-        for tens in self.tensors:
-            if not tens.ops:
-                op = Operation(Op.Const, tens.name)
-                op.set_output_tensor(tens)
+        fixup_tensors(self.inputs, self.tensors)
 
     def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
         tensors = []
diff --git a/ethosu/vela/tosa/ArithmeticRightShiftAttribute.py b/ethosu/vela/tosa/ArithmeticRightShiftAttribute.py
new file mode 100644
index 0000000..ad7b9b0
--- /dev/null
+++ b/ethosu/vela/tosa/ArithmeticRightShiftAttribute.py
@@ -0,0 +1,30 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class ArithmeticRightShiftAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsArithmeticRightShiftAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = ArithmeticRightShiftAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # ArithmeticRightShiftAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # ArithmeticRightShiftAttribute
+    def Round(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+        return False
+
+def ArithmeticRightShiftAttributeStart(builder): builder.StartObject(1)
+def ArithmeticRightShiftAttributeAddRound(builder, round): builder.PrependBoolSlot(0, round, 0)
+def ArithmeticRightShiftAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/Attribute.py b/ethosu/vela/tosa/Attribute.py
new file mode 100644
index 0000000..8a2ccae
--- /dev/null
+++ b/ethosu/vela/tosa/Attribute.py
@@ -0,0 +1,22 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+class Attribute(object):
+    NONE = 0
+    Pool2dAttribute = 1
+    Conv2dAttribute = 2
+    TransposeConv2dAttribute = 3
+    ReluNAttribute = 4
+    AxisAttribute = 5
+    ReshapeAttribute = 6
+    SliceAttribute = 7
+    TileAttribute = 8
+    ResizeAttribute = 9
+    ClampAttribute = 10
+    RescaleAttribute = 11
+    MulAttribute = 12
+    ArithmeticRightShiftAttribute = 13
+    CondIfAttribute = 14
+    WhileLoopAttribute = 15
+
diff --git a/ethosu/vela/tosa/AxisAttribute.py b/ethosu/vela/tosa/AxisAttribute.py
new file mode 100644
index 0000000..77f1bca
--- /dev/null
+++ b/ethosu/vela/tosa/AxisAttribute.py
@@ -0,0 +1,30 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class AxisAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsAxisAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = AxisAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # AxisAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # AxisAttribute
+    def Axis(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+def AxisAttributeStart(builder): builder.StartObject(1)
+def AxisAttributeAddAxis(builder, axis): builder.PrependInt32Slot(0, axis, 0)
+def AxisAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/ClampAttribute.py b/ethosu/vela/tosa/ClampAttribute.py
new file mode 100644
index 0000000..9cc8d4d
--- /dev/null
+++ b/ethosu/vela/tosa/ClampAttribute.py
@@ -0,0 +1,54 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class ClampAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsClampAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = ClampAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # ClampAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # ClampAttribute
+    def MinInt(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # ClampAttribute
+    def MaxInt(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # ClampAttribute
+    def MinFp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
+        return 0.0
+
+    # ClampAttribute
+    def MaxFp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
+        return 0.0
+
+def ClampAttributeStart(builder): builder.StartObject(4)
+def ClampAttributeAddMinInt(builder, minInt): builder.PrependInt32Slot(0, minInt, 0)
+def ClampAttributeAddMaxInt(builder, maxInt): builder.PrependInt32Slot(1, maxInt, 0)
+def ClampAttributeAddMinFp(builder, minFp): builder.PrependFloat32Slot(2, minFp, 0.0)
+def ClampAttributeAddMaxFp(builder, maxFp): builder.PrependFloat32Slot(3, maxFp, 0.0)
+def ClampAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/CondIfAttribute.py b/ethosu/vela/tosa/CondIfAttribute.py
new file mode 100644
index 0000000..bc19b95
--- /dev/null
+++ b/ethosu/vela/tosa/CondIfAttribute.py
@@ -0,0 +1,38 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class CondIfAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsCondIfAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = CondIfAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # CondIfAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # CondIfAttribute
+    def ThenBranch(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+    # CondIfAttribute
+    def ElseBranch(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+def CondIfAttributeStart(builder): builder.StartObject(2)
+def CondIfAttributeAddThenBranch(builder, thenBranch): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(thenBranch), 0)
+def CondIfAttributeAddElseBranch(builder, elseBranch): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(elseBranch), 0)
+def CondIfAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/Conv2dAttribute.py b/ethosu/vela/tosa/Conv2dAttribute.py
new file mode 100644
index 0000000..5faecbd
--- /dev/null
+++ b/ethosu/vela/tosa/Conv2dAttribute.py
@@ -0,0 +1,94 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class Conv2dAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsConv2dAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = Conv2dAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # Conv2dAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # Conv2dAttribute
+    def Padding(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # Conv2dAttribute
+    def PaddingAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # Conv2dAttribute
+    def PaddingLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # Conv2dAttribute
+    def Stride(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # Conv2dAttribute
+    def StrideAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # Conv2dAttribute
+    def StrideLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # Conv2dAttribute
+    def Dilation(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # Conv2dAttribute
+    def DilationAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # Conv2dAttribute
+    def DilationLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def Conv2dAttributeStart(builder): builder.StartObject(3)
+def Conv2dAttributeAddPadding(builder, padding): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0)
+def Conv2dAttributeStartPaddingVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Conv2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
+def Conv2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Conv2dAttributeAddDilation(builder, dilation): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dilation), 0)
+def Conv2dAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Conv2dAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/ConvQuantInfo.py b/ethosu/vela/tosa/ConvQuantInfo.py
new file mode 100644
index 0000000..9f785b3
--- /dev/null
+++ b/ethosu/vela/tosa/ConvQuantInfo.py
@@ -0,0 +1,38 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class ConvQuantInfo(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsConvQuantInfo(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = ConvQuantInfo()
+        x.Init(buf, n + offset)
+        return x
+
+    # ConvQuantInfo
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # ConvQuantInfo
+    def InputZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # ConvQuantInfo
+    def WeightZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+def ConvQuantInfoStart(builder): builder.StartObject(2)
+def ConvQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0)
+def ConvQuantInfoAddWeightZp(builder, weightZp): builder.PrependInt32Slot(1, weightZp, 0)
+def ConvQuantInfoEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/DType.py b/ethosu/vela/tosa/DType.py
new file mode 100644
index 0000000..65432db
--- /dev/null
+++ b/ethosu/vela/tosa/DType.py
@@ -0,0 +1,15 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+class DType(object):
+    UNKNOWN = 0
+    BOOL = 1
+    UINT8 = 2
+    INT4 = 3
+    INT8 = 4
+    INT16 = 5
+    INT32 = 6
+    INT48 = 7
+    FLOAT = 8
+
diff --git a/ethosu/vela/tosa/MatMulQuantInfo.py b/ethosu/vela/tosa/MatMulQuantInfo.py
new file mode 100644
index 0000000..6cc3a72
--- /dev/null
+++ b/ethosu/vela/tosa/MatMulQuantInfo.py
@@ -0,0 +1,38 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class MatMulQuantInfo(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsMatMulQuantInfo(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = MatMulQuantInfo()
+        x.Init(buf, n + offset)
+        return x
+
+    # MatMulQuantInfo
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # MatMulQuantInfo
+    def AZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # MatMulQuantInfo
+    def BZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+def MatMulQuantInfoStart(builder): builder.StartObject(2)
+def MatMulQuantInfoAddAZp(builder, aZp): builder.PrependInt32Slot(0, aZp, 0)
+def MatMulQuantInfoAddBZp(builder, bZp): builder.PrependInt32Slot(1, bZp, 0)
+def MatMulQuantInfoEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/MulAttribute.py b/ethosu/vela/tosa/MulAttribute.py
new file mode 100644
index 0000000..c08f684
--- /dev/null
+++ b/ethosu/vela/tosa/MulAttribute.py
@@ -0,0 +1,30 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class MulAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsMulAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = MulAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # MulAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # MulAttribute
+    def Shift(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+def MulAttributeStart(builder): builder.StartObject(1)
+def MulAttributeAddShift(builder, shift): builder.PrependInt32Slot(0, shift, 0)
+def MulAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/Op.py b/ethosu/vela/tosa/Op.py
new file mode 100644
index 0000000..c71ac44
--- /dev/null
+++ b/ethosu/vela/tosa/Op.py
@@ -0,0 +1,75 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+class Op(object):
+    UNKNOWN = 0
+    ARGMAX = 1
+    AVG_POOL2D = 2
+    CONV2D = 3
+    CONV3D = 4
+    DEPTHWISE_CONV2D = 5
+    FULLY_CONNECTED = 6
+    MATMUL = 7
+    MAX_POOL2D = 8
+    TRANSPOSE_CONV2D = 9
+    CLAMP = 10
+    RELUN = 11
+    SIGMOID = 12
+    TANH = 13
+    ADD = 14
+    ARITHMETIC_RIGHT_SHIFT = 15
+    BITWISE_AND = 16
+    BITWISE_OR = 17
+    BITWISE_XOR = 18
+    DIV = 19
+    LOGICAL_AND = 20
+    LOGICAL_LEFT_SHIFT = 21
+    LOGICAL_RIGHT_SHIFT = 22
+    LOGICAL_OR = 23
+    LOGICAL_XOR = 24
+    MAXIMUM = 25
+    MINIMUM = 26
+    MUL = 27
+    POW = 28
+    SUB = 29
+    TABLE = 30
+    ABS = 31
+    BITWISE_NOT = 32
+    CEIL = 33
+    CLZ = 34
+    EXP = 35
+    FLOOR = 36
+    LOG = 37
+    LOGICAL_NOT = 38
+    NEGATE = 39
+    RECIPROCAL = 40
+    RSQRT = 41
+    SELECT = 42
+    EQUAL = 43
+    GREATER = 44
+    GREATER_EQUAL = 45
+    REDUCE_ANY = 46
+    REDUCE_ALL = 47
+    REDUCE_MAX = 48
+    REDUCE_MIN = 49
+    REDUCE_PRODUCT = 50
+    REDUCE_SUM = 51
+    CONCAT = 52
+    PAD = 53
+    RESHAPE = 54
+    REVERSE = 55
+    SLICE = 56
+    TILE = 57
+    TRANSPOSE = 58
+    GATHER = 59
+    SCATTER = 60
+    RESIZE = 61
+    CAST = 62
+    RESCALE = 63
+    CONST = 64
+    IDENTITY = 65
+    CUSTOM = 66
+    COND_IF = 67
+    WHILE_LOOP = 68
+
diff --git a/ethosu/vela/tosa/PadQuantInfo.py b/ethosu/vela/tosa/PadQuantInfo.py
new file mode 100644
index 0000000..825b72c
--- /dev/null
+++ b/ethosu/vela/tosa/PadQuantInfo.py
@@ -0,0 +1,30 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class PadQuantInfo(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsPadQuantInfo(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = PadQuantInfo()
+        x.Init(buf, n + offset)
+        return x
+
+    # PadQuantInfo
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # PadQuantInfo
+    def InputZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+def PadQuantInfoStart(builder): builder.StartObject(1)
+def PadQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0)
+def PadQuantInfoEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/Pool2dAttribute.py b/ethosu/vela/tosa/Pool2dAttribute.py
new file mode 100644
index 0000000..a45c1f1
--- /dev/null
+++ b/ethosu/vela/tosa/Pool2dAttribute.py
@@ -0,0 +1,94 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class Pool2dAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsPool2dAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = Pool2dAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # Pool2dAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # Pool2dAttribute
+    def Padding(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # Pool2dAttribute
+    def PaddingAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # Pool2dAttribute
+    def PaddingLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # Pool2dAttribute
+    def Kernel(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # Pool2dAttribute
+    def KernelAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # Pool2dAttribute
+    def KernelLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # Pool2dAttribute
+    def Stride(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # Pool2dAttribute
+    def StrideAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # Pool2dAttribute
+    def StrideLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def Pool2dAttributeStart(builder): builder.StartObject(3)
+def Pool2dAttributeAddPadding(builder, padding): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(padding), 0)
+def Pool2dAttributeStartPaddingVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Pool2dAttributeAddKernel(builder, kernel): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(kernel), 0)
+def Pool2dAttributeStartKernelVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Pool2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
+def Pool2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def Pool2dAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/QuantInfo.py b/ethosu/vela/tosa/QuantInfo.py
new file mode 100644
index 0000000..ffdfd32
--- /dev/null
+++ b/ethosu/vela/tosa/QuantInfo.py
@@ -0,0 +1,11 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+class QuantInfo(object):
+    NONE = 0
+    UnaryQuantInfo = 1
+    ConvQuantInfo = 2
+    MatMulQuantInfo = 3
+    PadQuantInfo = 4
+
diff --git a/ethosu/vela/tosa/ReluNAttribute.py b/ethosu/vela/tosa/ReluNAttribute.py
new file mode 100644
index 0000000..008b535
--- /dev/null
+++ b/ethosu/vela/tosa/ReluNAttribute.py
@@ -0,0 +1,38 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class ReluNAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsReluNAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = ReluNAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # ReluNAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # ReluNAttribute
+    def MaxInt(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # ReluNAttribute
+    def MaxFp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
+        return 0.0
+
+def ReluNAttributeStart(builder): builder.StartObject(2)
+def ReluNAttributeAddMaxInt(builder, maxInt): builder.PrependInt32Slot(0, maxInt, 0)
+def ReluNAttributeAddMaxFp(builder, maxFp): builder.PrependFloat32Slot(1, maxFp, 0.0)
+def ReluNAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/RescaleAttribute.py b/ethosu/vela/tosa/RescaleAttribute.py
new file mode 100644
index 0000000..1aa6707
--- /dev/null
+++ b/ethosu/vela/tosa/RescaleAttribute.py
@@ -0,0 +1,110 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class RescaleAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsRescaleAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = RescaleAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # RescaleAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # RescaleAttribute
+    def InputZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # RescaleAttribute
+    def OutputZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # RescaleAttribute
+    def Multiplier(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # RescaleAttribute
+    def MultiplierAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # RescaleAttribute
+    def MultiplierLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # RescaleAttribute
+    def Shift(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # RescaleAttribute
+    def ShiftAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # RescaleAttribute
+    def ShiftLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # RescaleAttribute
+    def Scale32(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+        return False
+
+    # RescaleAttribute
+    def DoubleRound(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+        if o != 0:
+            return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+        return False
+
+    # RescaleAttribute
+    def PerChannel(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+        if o != 0:
+            return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+        return False
+
+def RescaleAttributeStart(builder): builder.StartObject(7)
+def RescaleAttributeAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0)
+def RescaleAttributeAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0)
+def RescaleAttributeAddMultiplier(builder, multiplier): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(multiplier), 0)
+def RescaleAttributeStartMultiplierVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def RescaleAttributeAddShift(builder, shift): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(shift), 0)
+def RescaleAttributeStartShiftVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def RescaleAttributeAddScale32(builder, scale32): builder.PrependBoolSlot(4, scale32, 0)
+def RescaleAttributeAddDoubleRound(builder, doubleRound): builder.PrependBoolSlot(5, doubleRound, 0)
+def RescaleAttributeAddPerChannel(builder, perChannel): builder.PrependBoolSlot(6, perChannel, 0)
+def RescaleAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/ReshapeAttribute.py b/ethosu/vela/tosa/ReshapeAttribute.py
new file mode 100644
index 0000000..629b6c2
--- /dev/null
+++ b/ethosu/vela/tosa/ReshapeAttribute.py
@@ -0,0 +1,46 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class ReshapeAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsReshapeAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = ReshapeAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # ReshapeAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # ReshapeAttribute
+    def Shape(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # ReshapeAttribute
+    def ShapeAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # ReshapeAttribute
+    def ShapeLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def ReshapeAttributeStart(builder): builder.StartObject(1)
+def ReshapeAttributeAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
+def ReshapeAttributeStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ReshapeAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/ResizeAttribute.py b/ethosu/vela/tosa/ResizeAttribute.py
new file mode 100644
index 0000000..89156d1
--- /dev/null
+++ b/ethosu/vela/tosa/ResizeAttribute.py
@@ -0,0 +1,158 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class ResizeAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsResizeAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = ResizeAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # ResizeAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # ResizeAttribute
+    def OutputSize(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # ResizeAttribute
+    def OutputSizeAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # ResizeAttribute
+    def OutputSizeLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # ResizeAttribute
+    def Stride(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # ResizeAttribute
+    def StrideAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # ResizeAttribute
+    def StrideLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # ResizeAttribute
+    def Offset(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # ResizeAttribute
+    def OffsetAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # ResizeAttribute
+    def OffsetLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # ResizeAttribute
+    def Shift(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # ResizeAttribute
+    def StrideFp(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Float32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # ResizeAttribute
+    def StrideFpAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o)
+        return 0
+
+    # ResizeAttribute
+    def StrideFpLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # ResizeAttribute
+    def OffsetFp(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Float32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # ResizeAttribute
+    def OffsetFpAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o)
+        return 0
+
+    # ResizeAttribute
+    def OffsetFpLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # ResizeAttribute
+    def Mode(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+        return 0
+
+def ResizeAttributeStart(builder): builder.StartObject(7)
+def ResizeAttributeAddOutputSize(builder, outputSize): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outputSize), 0)
+def ResizeAttributeStartOutputSizeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
+def ResizeAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddOffset(builder, offset): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(offset), 0)
+def ResizeAttributeStartOffsetVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddShift(builder, shift): builder.PrependInt32Slot(3, shift, 0)
+def ResizeAttributeAddStrideFp(builder, strideFp): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(strideFp), 0)
+def ResizeAttributeStartStrideFpVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddOffsetFp(builder, offsetFp): builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(offsetFp), 0)
+def ResizeAttributeStartOffsetFpVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddMode(builder, mode): builder.PrependUint32Slot(6, mode, 0)
+def ResizeAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/ResizeMode.py b/ethosu/vela/tosa/ResizeMode.py
new file mode 100644
index 0000000..65bcd5d
--- /dev/null
+++ b/ethosu/vela/tosa/ResizeMode.py
@@ -0,0 +1,9 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+class ResizeMode(object):
+    UNKNOWN = 0
+    NEAREST = 1
+    BILINEAR = 2
+
diff --git a/ethosu/vela/tosa/SliceAttribute.py b/ethosu/vela/tosa/SliceAttribute.py
new file mode 100644
index 0000000..d2f9958
--- /dev/null
+++ b/ethosu/vela/tosa/SliceAttribute.py
@@ -0,0 +1,70 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class SliceAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsSliceAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = SliceAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # SliceAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # SliceAttribute
+    def Begin(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # SliceAttribute
+    def BeginAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # SliceAttribute
+    def BeginLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # SliceAttribute
+    def Size(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # SliceAttribute
+    def SizeAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # SliceAttribute
+    def SizeLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def SliceAttributeStart(builder): builder.StartObject(2)
+def SliceAttributeAddBegin(builder, begin): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(begin), 0)
+def SliceAttributeStartBeginVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def SliceAttributeAddSize(builder, size): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(size), 0)
+def SliceAttributeStartSizeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def SliceAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/TileAttribute.py b/ethosu/vela/tosa/TileAttribute.py
new file mode 100644
index 0000000..f1b721b
--- /dev/null
+++ b/ethosu/vela/tosa/TileAttribute.py
@@ -0,0 +1,46 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class TileAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsTileAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = TileAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # TileAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # TileAttribute
+    def Multiples(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # TileAttribute
+    def MultiplesAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # TileAttribute
+    def MultiplesLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def TileAttributeStart(builder): builder.StartObject(1)
+def TileAttributeAddMultiples(builder, multiples): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(multiples), 0)
+def TileAttributeStartMultiplesVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TileAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/TosaBasicBlock.py b/ethosu/vela/tosa/TosaBasicBlock.py
new file mode 100644
index 0000000..8f1604a
--- /dev/null
+++ b/ethosu/vela/tosa/TosaBasicBlock.py
@@ -0,0 +1,108 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class TosaBasicBlock(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsTosaBasicBlock(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = TosaBasicBlock()
+        x.Init(buf, n + offset)
+        return x
+
+    # TosaBasicBlock
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # TosaBasicBlock
+    def Name(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+    # TosaBasicBlock
+    def Operators(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            x = self._tab.Vector(o)
+            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
+            x = self._tab.Indirect(x)
+            from .TosaOperator import TosaOperator
+            obj = TosaOperator()
+            obj.Init(self._tab.Bytes, x)
+            return obj
+        return None
+
+    # TosaBasicBlock
+    def OperatorsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TosaBasicBlock
+    def Tensors(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            x = self._tab.Vector(o)
+            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
+            x = self._tab.Indirect(x)
+            from .TosaTensor import TosaTensor
+            obj = TosaTensor()
+            obj.Init(self._tab.Bytes, x)
+            return obj
+        return None
+
+    # TosaBasicBlock
+    def TensorsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TosaBasicBlock
+    def Inputs(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # TosaBasicBlock
+    def InputsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TosaBasicBlock
+    def Outputs(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # TosaBasicBlock
+    def OutputsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def TosaBasicBlockStart(builder): builder.StartObject(5)
+def TosaBasicBlockAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
+def TosaBasicBlockAddOperators(builder, operators): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(operators), 0)
+def TosaBasicBlockStartOperatorsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaBasicBlockAddTensors(builder, tensors): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(tensors), 0)
+def TosaBasicBlockStartTensorsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaBasicBlockAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0)
+def TosaBasicBlockStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaBasicBlockAddOutputs(builder, outputs): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0)
+def TosaBasicBlockStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaBasicBlockEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/TosaGraph.py b/ethosu/vela/tosa/TosaGraph.py
new file mode 100644
index 0000000..f54a44a
--- /dev/null
+++ b/ethosu/vela/tosa/TosaGraph.py
@@ -0,0 +1,56 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class TosaGraph(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsTosaGraph(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = TosaGraph()
+        x.Init(buf, n + offset)
+        return x
+
+    # TosaGraph
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # TosaGraph
+    def Version(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            x = self._tab.Indirect(o + self._tab.Pos)
+            from .Version import Version
+            obj = Version()
+            obj.Init(self._tab.Bytes, x)
+            return obj
+        return None
+
+    # TosaGraph
+    def Blocks(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            x = self._tab.Vector(o)
+            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
+            x = self._tab.Indirect(x)
+            from .TosaBasicBlock import TosaBasicBlock
+            obj = TosaBasicBlock()
+            obj.Init(self._tab.Bytes, x)
+            return obj
+        return None
+
+    # TosaGraph
+    def BlocksLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def TosaGraphStart(builder): builder.StartObject(2)
+def TosaGraphAddVersion(builder, version): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(version), 0)
+def TosaGraphAddBlocks(builder, blocks): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(blocks), 0)
+def TosaGraphStartBlocksVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaGraphEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/TosaOperator.py b/ethosu/vela/tosa/TosaOperator.py
new file mode 100644
index 0000000..c45efda
--- /dev/null
+++ b/ethosu/vela/tosa/TosaOperator.py
@@ -0,0 +1,102 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class TosaOperator(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsTosaOperator(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = TosaOperator()
+        x.Init(buf, n + offset)
+        return x
+
+    # TosaOperator
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # TosaOperator
+    def Op(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+        return 0
+
+    # TosaOperator
+    def AttributeType(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
+        return 0
+
+    # TosaOperator
+    def Attribute(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            from flatbuffers.table import Table
+            obj = Table(bytearray(), 0)
+            self._tab.Union(obj, o)
+            return obj
+        return None
+
+    # TosaOperator
+    def Inputs(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # TosaOperator
+    def InputsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TosaOperator
+    def Outputs(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return ""
+
+    # TosaOperator
+    def OutputsLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TosaOperator
+    def QuantInfoType(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
+        return 0
+
+    # TosaOperator
+    def QuantInfo(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+        if o != 0:
+            from flatbuffers.table import Table
+            obj = Table(bytearray(), 0)
+            self._tab.Union(obj, o)
+            return obj
+        return None
+
+def TosaOperatorStart(builder): builder.StartObject(7)
+def TosaOperatorAddOp(builder, op): builder.PrependUint32Slot(0, op, 0)
+def TosaOperatorAddAttributeType(builder, attributeType): builder.PrependUint8Slot(1, attributeType, 0)
+def TosaOperatorAddAttribute(builder, attribute): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(attribute), 0)
+def TosaOperatorAddInputs(builder, inputs): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(inputs), 0)
+def TosaOperatorStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaOperatorAddOutputs(builder, outputs): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outputs), 0)
+def TosaOperatorStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaOperatorAddQuantInfoType(builder, quantInfoType): builder.PrependUint8Slot(5, quantInfoType, 0)
+def TosaOperatorAddQuantInfo(builder, quantInfo): builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(quantInfo), 0)
+def TosaOperatorEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/TosaTensor.py b/ethosu/vela/tosa/TosaTensor.py
new file mode 100644
index 0000000..2a397db
--- /dev/null
+++ b/ethosu/vela/tosa/TosaTensor.py
@@ -0,0 +1,70 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class TosaTensor(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsTosaTensor(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = TosaTensor()
+        x.Init(buf, n + offset)
+        return x
+
+    # TosaTensor
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # TosaTensor
+    def Name(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+    # TosaTensor
+    def Shape(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # TosaTensor
+    def ShapeAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # TosaTensor
+    def ShapeLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TosaTensor
+    def Type(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+        return 0
+
+    # TosaTensor
+    def NpyFilename(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+def TosaTensorStart(builder): builder.StartObject(4)
+def TosaTensorAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
+def TosaTensorAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
+def TosaTensorStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TosaTensorAddType(builder, type): builder.PrependUint32Slot(2, type, 0)
+def TosaTensorAddNpyFilename(builder, npyFilename): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(npyFilename), 0)
+def TosaTensorEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/TransposeConv2dAttribute.py b/ethosu/vela/tosa/TransposeConv2dAttribute.py
new file mode 100644
index 0000000..80f544f
--- /dev/null
+++ b/ethosu/vela/tosa/TransposeConv2dAttribute.py
@@ -0,0 +1,118 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class TransposeConv2dAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsTransposeConv2dAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = TransposeConv2dAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # TransposeConv2dAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # TransposeConv2dAttribute
+    def Outpad(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # TransposeConv2dAttribute
+    def OutpadAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # TransposeConv2dAttribute
+    def OutpadLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TransposeConv2dAttribute
+    def Stride(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # TransposeConv2dAttribute
+    def StrideAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # TransposeConv2dAttribute
+    def StrideLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TransposeConv2dAttribute
+    def Dilation(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # TransposeConv2dAttribute
+    def DilationAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # TransposeConv2dAttribute
+    def DilationLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+    # TransposeConv2dAttribute
+    def OutputShape(self, j):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            a = self._tab.Vector(o)
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+        return 0
+
+    # TransposeConv2dAttribute
+    def OutputShapeAsNumpy(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
+        return 0
+
+    # TransposeConv2dAttribute
+    def OutputShapeLength(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return self._tab.VectorLen(o)
+        return 0
+
+def TransposeConv2dAttributeStart(builder): builder.StartObject(4)
+def TransposeConv2dAttributeAddOutpad(builder, outpad): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outpad), 0)
+def TransposeConv2dAttributeStartOutpadVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeConv2dAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
+def TransposeConv2dAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeConv2dAttributeAddDilation(builder, dilation): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(dilation), 0)
+def TransposeConv2dAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeConv2dAttributeAddOutputShape(builder, outputShape): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(outputShape), 0)
+def TransposeConv2dAttributeStartOutputShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def TransposeConv2dAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/UnaryQuantInfo.py b/ethosu/vela/tosa/UnaryQuantInfo.py
new file mode 100644
index 0000000..7111c6c
--- /dev/null
+++ b/ethosu/vela/tosa/UnaryQuantInfo.py
@@ -0,0 +1,38 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class UnaryQuantInfo(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsUnaryQuantInfo(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = UnaryQuantInfo()
+        x.Init(buf, n + offset)
+        return x
+
+    # UnaryQuantInfo
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # UnaryQuantInfo
+    def InputZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # UnaryQuantInfo
+    def OutputZp(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+def UnaryQuantInfoStart(builder): builder.StartObject(2)
+def UnaryQuantInfoAddInputZp(builder, inputZp): builder.PrependInt32Slot(0, inputZp, 0)
+def UnaryQuantInfoAddOutputZp(builder, outputZp): builder.PrependInt32Slot(1, outputZp, 0)
+def UnaryQuantInfoEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/Version.py b/ethosu/vela/tosa/Version.py
new file mode 100644
index 0000000..403add3
--- /dev/null
+++ b/ethosu/vela/tosa/Version.py
@@ -0,0 +1,54 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class Version(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsVersion(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = Version()
+        x.Init(buf, n + offset)
+        return x
+
+    # Version
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # Version
+    def _major(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # Version
+    def _minor(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 22
+
+    # Version
+    def _patch(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return 0
+
+    # Version
+    def _experimental(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+        if o != 0:
+            return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+        return False
+
+def VersionStart(builder): builder.StartObject(4)
+def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0)
+def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 22)
+def VersionAdd_patch(builder, Patch): builder.PrependInt32Slot(2, Patch, 0)
+def VersionAdd_experimental(builder, Experimental): builder.PrependBoolSlot(3, Experimental, 0)
+def VersionEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/WhileLoopAttribute.py b/ethosu/vela/tosa/WhileLoopAttribute.py
new file mode 100644
index 0000000..68655c9
--- /dev/null
+++ b/ethosu/vela/tosa/WhileLoopAttribute.py
@@ -0,0 +1,38 @@
+# automatically generated by the FlatBuffers compiler, do not modify
+
+# namespace: tosa
+
+import flatbuffers
+
+class WhileLoopAttribute(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAsWhileLoopAttribute(cls, buf, offset):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = WhileLoopAttribute()
+        x.Init(buf, n + offset)
+        return x
+
+    # WhileLoopAttribute
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # WhileLoopAttribute
+    def CondBranch(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+    # WhileLoopAttribute
+    def BodyBranch(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.String(o + self._tab.Pos)
+        return None
+
+def WhileLoopAttributeStart(builder): builder.StartObject(2)
+def WhileLoopAttributeAddCondBranch(builder, condBranch): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(condBranch), 0)
+def WhileLoopAttributeAddBodyBranch(builder, bodyBranch): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(bodyBranch), 0)
+def WhileLoopAttributeEnd(builder): return builder.EndObject()
diff --git a/ethosu/vela/tosa/__init__.py b/ethosu/vela/tosa/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/ethosu/vela/tosa/__init__.py
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
new file mode 100644
index 0000000..94e6f99
--- /dev/null
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -0,0 +1,196 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph.
+from . import rewrite_graph
+from .api import NpuRoundingMode
+from .data_type import DataType
+from .debug_database import DebugDatabase
+from .graph_optimiser_util import needed_total_padding
+from .graph_optimiser_util import set_ifm_ofm_op_shapes
+from .graph_optimiser_util import set_tensor_equivalence
+from .operation import ExplicitScaling
+from .operation import NpuBlockType
+from .operation import Op
+from .operation import Padding
+
+
+def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
+    k_w, k_h = kernel.dilated_wh()
+    s_x, s_y = kernel.stride
+    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
+    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
+    left_pad, right_pad, top_pad, bottom_pad = explicit_padding
+
+    padding = (top_pad, left_pad, bottom_pad, right_pad)
+    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
+    return padding, skirt
+
+
+def add_padding_fields(op, arch, nng):
+    if op.run_on_npu:
+        if "padding" in op.attrs:
+            input_shape = op.ifm_shapes[0]
+
+            if op.type == Op.Conv2DBackpropInputSwitchedBias:
+                # TODO not yet supported, but there will be need for separate handling
+                assert False
+            else:
+                padding, skirt = calc_padding_and_skirt(
+                    Padding.EXPLICIT, op.kernel, input_shape, op.attrs.get("padding"),
+                )
+
+            op.attrs["explicit_padding"] = padding
+            op.attrs["skirt"] = skirt
+
+    return op
+
+
+def rewrite_activation(op, arch, nng):
+    if not op.type.is_relu_op():
+        return op
+
+    ifm = op.ifm
+    prev_op = ifm.ops[0]
+
+    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
+    fuseable = (
+        prev_op.run_on_npu
+        and prev_op.type.npu_block_type != NpuBlockType.Default
+        and len(ifm.ops) == 1
+        and len(prev_op.outputs[0].consumers()) == 1
+        and prev_op.activation is None
+    )
+    if not fuseable:
+        print("Warning: relu like op will not be possible to fuse, currently not supported")
+        assert False
+
+    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
+    if op.ofm.quantization.zero_point is None:
+        op.ofm.quantization.zero_point = zp
+
+    if op.type == Op.Clip:
+        op.attrs["min"] = op.attrs["min_int"] - zp
+        op.attrs["max"] = op.attrs["max_int"] - zp
+    elif op.type == Op.ReluN:
+        op.attrs["max"] = op.attrs["max_int"] - zp
+    else:
+        print("Warning: Unknown TOSA activation Op")
+        assert False
+
+    return op
+
+
+def rewrite_rescale(op, arch, nng):
+    if op.type == Op.Rescale:
+        ifm = op.ifm
+        ofm = op.ofm
+
+        # some error checking
+        assert len(ifm.ops) == 1
+        prev_op = ifm.ops[0]
+
+        # TODO currently not supported
+        assert prev_op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+        assert len(ifm.consumer_list) == 1
+
+        input_zp = op.attrs["input_zp"]
+        output_zp = op.attrs["output_zp"]
+        multiplier = op.attrs["multiplier"]
+        shift = op.attrs["shift"]
+        scale32 = op.attrs["scale32"]
+        double_round = op.attrs["double_round"]
+        per_channel = op.attrs["per_channel"]
+
+        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
+        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
+        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
+        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)
+
+        # Check that input tensor has the same zp or no zp
+        ifm_zp = ifm.quantization.zero_point
+        if ifm_zp is not None and ifm_zp != input_zp:
+            print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedidly ")
+            assert False
+        ifm.quantization.zero_point = input_zp
+
+        if not scale32:
+            double_round = False
+
+        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
+            assert len(multiplier) == len(shift) == len(prev_op.bias.values)
+
+            if ifm.dtype == DataType.int32 and per_channel:
+                for s, m in zip(shift, multiplier):
+                    # TODO these are the TOSA limitations
+                    assert m >= 0
+                    assert 2 <= s <= 62
+                    # TODO these are the HW limitations
+                    assert 0 <= s < (1 << 6)
+                prev_op.explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
+                ofm.quantization.zero_point = output_zp
+
+                if double_round:
+                    prev_op.rounding_mode = NpuRoundingMode.TFL
+                else:
+                    prev_op.rounding_mode = NpuRoundingMode.NATURAL
+
+                # Bypass op
+                prev_op.set_output_tensor(ofm)
+                DebugDatabase.add_optimised(op, prev_op)
+                return op
+            else:
+                print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
+                assert False
+
+        else:
+            print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
+            assert False
+    return op
+
+
+def supported_operator_check(op, arch, nng):
+    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
+    return op
+
+
+def tosa_optimise_graph(nng, arch):
+    # Pre-processing step
+    pre_process_list = [
+        supported_operator_check,
+        set_ifm_ofm_op_shapes,
+    ]
+
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+        )
+
+    # Rewite Operators step
+    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale]
+
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
+        )
+
+    # Post-processing step
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
+        )
+
+    return nng
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
new file mode 100644
index 0000000..82f61f7
--- /dev/null
+++ b/ethosu/vela/tosa_mapping.py
@@ -0,0 +1,325 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# TOSA mapping functions used by reader.
+# Contains a mapping from the various TOSA enums and options structs, generated by the FlatBuffer code
+# generator, to Vela's internal format.
+from .data_type import DataType
+from .operation import Op
+from .operation import TensorIndices
+from .tosa import ArithmeticRightShiftAttribute  # noqa: F401
+from .tosa import AxisAttribute  # noqa: F401
+from .tosa import ClampAttribute  # noqa: F401
+from .tosa import CondIfAttribute  # noqa: F401
+from .tosa import Conv2dAttribute  # noqa: F401
+from .tosa import ConvQuantInfo  # noqa: F401
+from .tosa import MatMulQuantInfo  # noqa: F401
+from .tosa import MulAttribute  # noqa: F401
+from .tosa import PadQuantInfo  # noqa: F401
+from .tosa import Pool2dAttribute  # noqa: F401
+from .tosa import ReluNAttribute  # noqa: F401
+from .tosa import RescaleAttribute  # noqa: F401
+from .tosa import ReshapeAttribute  # noqa: F401
+from .tosa import ResizeAttribute  # noqa: F401
+from .tosa import SliceAttribute  # noqa: F401
+from .tosa import TileAttribute  # noqa: F401
+from .tosa import TransposeConv2dAttribute  # noqa: F401
+from .tosa import UnaryQuantInfo  # noqa: F401
+from .tosa import WhileLoopAttribute  # noqa: F401
+from .tosa.DType import DType
+from .tosa.Op import Op as TosaOp
+
+
+datatype_map = {
+    DType.BOOL: DataType.bool,
+    DType.UINT8: DataType.uint8,
+    DType.INT4: DataType.int4,
+    DType.INT8: DataType.int8,
+    DType.INT16: DataType.int16,
+    DType.INT32: DataType.int32,
+    DType.INT48: DataType.int48,
+    DType.FLOAT: DataType.float32,
+}
+
+
+# TODO duplicate of tflite_mapping
+def underscore_to_camel_case(s):
+    return "".join(x.title() for x in s.split("_"))
+
+
+# TODO duplicate of tflite_mapping
+def identity(x):
+    return x
+
+
+class AttrSerializer:
+    def __init__(self, name, members=None):
+        self.name = name
+        self.module = globals()[self.name]
+        self.cls = getattr(self.module, self.name)
+        self.members = []
+        if members is not None:
+            for mem in members:
+                deserialize = identity
+                is_vector = False
+                if isinstance(mem, tuple):
+                    if len(mem) == 2:
+                        mem, is_vector = mem
+                        deserialize = tuple
+                    else:
+                        assert 0
+                underscore_mem = mem
+                camelcase_mem = underscore_to_camel_case(mem)
+                self.members.append((underscore_mem, camelcase_mem, deserialize, is_vector))
+
+    def deserialize(self, op_data):
+        attr_type = op_data.AttributeType()
+        attr = op_data.Attribute()
+        attrs = {}
+        if attr_type:
+            tosa_attrs = self.cls()
+            tosa_attrs.Init(attr.Bytes, attr.Pos)
+            for underscore_mem, camelcase_mem, deserialize, is_vector in self.members:
+                fun = camelcase_mem
+                if is_vector:
+                    fun += "AsNumpy"
+
+                attr = getattr(tosa_attrs, fun)()
+                try:
+                    attrs[underscore_mem] = deserialize(attr)
+                except TypeError:
+                    print("Warning: {0} could not read attribute '{1}'.".format(self.name, underscore_mem))
+
+        return attrs
+
+
+class QuantSerializer:
+    def __init__(self, name, members=None):
+        self.name = name
+        self.module = globals()[self.name]
+        self.cls = getattr(self.module, self.name)
+        self.members = []
+        if members is not None:
+            for mem in members:
+                deserialize = identity
+                underscore_mem = mem
+                camelcase_mem = underscore_to_camel_case(mem)
+                self.members.append((underscore_mem, camelcase_mem, deserialize))
+
+    def deserialize(self, op_data):
+        quant_info_type = op_data.QuantInfoType()
+        quant_info = op_data.QuantInfo()
+        quant = {}
+        if quant_info_type:
+            tosa_quant = self.cls()
+            tosa_quant.Init(quant_info.Bytes, quant_info.Pos)
+            for underscore_mem, camelcase_mem, deserialize in self.members:
+                attr = getattr(tosa_quant, camelcase_mem)()
+                try:
+                    quant[underscore_mem] = deserialize(attr)
+                except TypeError:
+                    print("Warning: {0} could not read quant info '{1}'.".format(self.name, underscore_mem))
+
+        return quant
+
+
+is_vec = True
+pool2d_attrs = AttrSerializer("Pool2dAttribute", (("padding", is_vec), ("kernel", is_vec), ("stride", is_vec)))
+conv2d_attrs = AttrSerializer("Conv2dAttribute", (("padding", is_vec), ("stride", is_vec), ("dilation", is_vec)))
+transpose_conv2d_attrs = AttrSerializer(
+    "TransposeConv2dAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec))
+)
+relun_attrs = AttrSerializer("ReluNAttribute", ("max_int"))
+axis_attrs = AttrSerializer("AxisAttribute", ("axis"))
+reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),))
+slice_attrs = AttrSerializer("SliceAttribute", (("begin", is_vec), ("size", is_vec)))
+tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),))
+resize_attrs = AttrSerializer(
+    "ResizeAttribute", (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), ("shift"))
+)
+clamp_attrs = AttrSerializer("ClampAttribute", (("min_int"), ("max_int")))
+rescale_attrs = AttrSerializer(
+    "RescaleAttribute",
+    ("input_zp", "output_zp", ("multiplier", is_vec), ("shift", is_vec), "scale32", "double_round", "per_channel"),
+)
+mul_attrs = AttrSerializer("MulAttribute", ("shift"))
+ars_attrs = AttrSerializer("ArithmeticRightShiftAttribute", ("round",))
+condif_attrs = AttrSerializer("CondIfAttribute", (("then_branch"), ("else_branch")))  # TODO these are references
+while_attrs = AttrSerializer("WhileLoopAttribute", (("cond_branch"), ("body_branch")))  # TODO these are references
+
+unary_quant_info = QuantSerializer("UnaryQuantInfo", ("input_zp", "output_zp"))
+conv_quant_info = QuantSerializer("ConvQuantInfo", ("input_zp", "weight_zp"))
+matmul_quant_info = QuantSerializer("MatMulQuantInfo", ("a_zp", "b_zp"))
+pad_quant_info = QuantSerializer("PadQuantInfo", ("input_zp"))
+
+unsupported_tosa_operators = {
+    TosaOp.UNKNOWN,
+    TosaOp.ARGMAX,
+    TosaOp.CONV3D,
+    TosaOp.MATMUL,
+    TosaOp.TRANSPOSE_CONV2D,
+    TosaOp.SIGMOID,
+    TosaOp.TANH,
+    TosaOp.BITWISE_AND,
+    TosaOp.BITWISE_OR,
+    TosaOp.BITWISE_XOR,
+    TosaOp.DIV,
+    TosaOp.LOGICAL_AND,
+    TosaOp.LOGICAL_LEFT_SHIFT,
+    TosaOp.LOGICAL_RIGHT_SHIFT,
+    TosaOp.LOGICAL_OR,
+    TosaOp.LOGICAL_XOR,
+    TosaOp.MAXIMUM,
+    TosaOp.MINIMUM,
+    TosaOp.MUL,
+    TosaOp.POW,
+    TosaOp.TABLE,
+    TosaOp.ABS,
+    TosaOp.BITWISE_NOT,
+    TosaOp.CEIL,
+    TosaOp.CLZ,
+    TosaOp.EXP,
+    TosaOp.FLOOR,
+    TosaOp.LOG,
+    TosaOp.LOGICAL_NOT,
+    TosaOp.NEGATE,
+    TosaOp.RECIPROCAL,
+    TosaOp.RSQRT,
+    TosaOp.SELECT,
+    TosaOp.EQUAL,
+    TosaOp.GREATER,
+    TosaOp.GREATER_EQUAL,
+    TosaOp.REDUCE_ANY,
+    TosaOp.REDUCE_ALL,
+    TosaOp.REDUCE_MAX,
+    TosaOp.REDUCE_MIN,
+    TosaOp.REDUCE_PRODUCT,
+    TosaOp.REDUCE_SUM,
+    TosaOp.CONCAT,
+    TosaOp.PAD,
+    TosaOp.RESHAPE,
+    TosaOp.REVERSE,
+    TosaOp.SLICE,
+    TosaOp.TILE,
+    TosaOp.TRANSPOSE,
+    TosaOp.GATHER,
+    TosaOp.SCATTER,
+    TosaOp.RESIZE,
+    TosaOp.CAST,
+    TosaOp.IDENTITY,
+    TosaOp.CUSTOM,
+    TosaOp.COND_IF,
+    TosaOp.WHILE_LOOP,
+}
+
+
+TOSA_NO_INDICES = TensorIndices([], [], [])
+TOSA_IFM_INDICES = TensorIndices([0], [], [])
+# TOSA_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
+TOSA_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
+TOSA_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
+# TOSA_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
+# TOSA_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
+# TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])
+# TOSA_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
+# TOSA_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
+
+
+tosa_operator_map = {
+    # TosaOp.UNKNOWN: (),
+    # TODO TosaOp.ARGMAX: (Op.ArgMax, axis_attrs, None),
+    TosaOp.AVG_POOL2D: (Op.AvgPool, pool2d_attrs, unary_quant_info, TOSA_IFM_INDICES),
+    TosaOp.CONV2D: (Op.Conv2DBias, conv2d_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+    # TODO TosaOp.CONV3D:
+    TosaOp.DEPTHWISE_CONV2D: (Op.DepthwiseConv2DBias, conv2d_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+    TosaOp.FULLY_CONNECTED: (Op.FullyConnected, None, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+    # TODO TosaOp.MATMUL:
+    TosaOp.MAX_POOL2D: (Op.MaxPool, pool2d_attrs, None, TOSA_IFM_INDICES),
+    # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv2d_attrs, conv_quant_info)
+    TosaOp.CLAMP: (Op.Clip, clamp_attrs, None, TOSA_IFM_INDICES),
+    TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
+    # TODO TosaOp.SIGMOID
+    # TODO TosaOp.TANH
+    TosaOp.ADD: (Op.Add, None, None, TOSA_IFM_IFM2_INDICES),
+    TosaOp.ARITHMETIC_RIGHT_SHIFT: (Op.SHR, ars_attrs, None, TOSA_IFM_IFM2_INDICES),
+    # TODO TosaOp.BITWISE_AND
+    # TODO TosaOp.BITWISE_OR
+    # TODO TosaOp.BITWISE_XOR
+    # TODO TosaOp.DIV
+    # TODO TosaOp.LOGICAL_AND
+    # TODO TosaOp.LOGICAL_LEFT_SHIFT
+    # TODO TosaOp.LOGICAL_RIGHT_SHIFT
+    # TODO TosaOp.LOGICAL_OR
+    # TODO TosaOp.LOGICAL_XOR
+    # TODO TosaOp.MAXIMUM
+    # TODO TosaOp.MINIMUM
+    # TODO TosaOp.MUL
+    # TODO TosaOp.POW
+    TosaOp.SUB: (Op.Sub, None, None, TOSA_IFM_IFM2_INDICES),
+    # TODO TosaOp.TABLE
+    # TODO TosaOp.ABS
+    # TODO TosaOp.BITWISE_NOT
+    # TODO TosaOp.CEIL
+    # TODO TosaOp.CLZ
+    # TODO TosaOp.EXP
+    # TODO TosaOp.FLOOR
+    # TODO TosaOp.LOG
+    # TODO TosaOp.LOGICAL_NOT
+    # TODO TosaOp.NEGATE
+    # TODO TosaOp.RECIPROCAL
+    # TODO TosaOp.RSQRT
+    # TODO TosaOp.SELECT
+    # TODO TosaOp.EQUAL
+    # TODO TosaOp.GREATER
+    # TODO TosaOp.GREATER_EQUAL
+    # TODO TosaOp.REDUCE_ANY
+    # TODO TosaOp.REDUCE_ALL
+    # TODO TosaOp.REDUCE_MAX
+    # TODO TosaOp.REDUCE_MIN
+    # TODO TosaOp.REDUCE_PRODUCT
+    # TODO TosaOp.REDUCE_SUM
+    # TODO TosaOp.CONCAT
+    # TODO TosaOp.PAD
+    # TODO TosaOp.RESHAPE
+    # TODO TosaOp.REVERSE
+    # TODO TosaOp.SLICE
+    # TODO TosaOp.TILE
+    # TODO TosaOp.TRANSPOSE
+    # TODO TosaOp.GATHER
+    # TODO TosaOp.SCATTER
+    # TODO TosaOp.RESIZE
+    # TODO TosaOp.CAST
+    TosaOp.RESCALE: (Op.Rescale, rescale_attrs, None, TOSA_IFM_INDICES),
+    TosaOp.CONST: (Op.Const, None, None, TOSA_NO_INDICES),
+    # TODO TosaOp.IDENTITY
+    # TODO TosaOp.CUSTOM
+    # TODO TosaOp.COND_IF
+    # TODO TosaOp.WHILE_LOOP
+}
+
+tosa_operator_inv_map = {v[0]: (k, v[1]) for k, v in tosa_operator_map.items()}
+
+
+def tosa_type_name(builtin):
+    return next(k for k, v in vars(TosaOp).items() if v == builtin)
+
+
+# TODO will return UNKNOWN for the once that have not yet been defined in tosa_operator_map
+def optype_to_tosa_op_type(op_type):
+    if op_type in tosa_operator_inv_map:
+        return tosa_type_name(tosa_operator_inv_map[op_type][0])
+    else:
+        return TosaOp.UNKNOWN
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
new file mode 100644
index 0000000..ac0b396
--- /dev/null
+++ b/ethosu/vela/tosa_reader.py
@@ -0,0 +1,259 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Functions used to read from a TOSA format file.
+import os.path
+import struct
+import sys
+
+import numpy as np
+
+from .nn_graph import Graph
+from .nn_graph import Subgraph
+from .operation import Op
+from .operation import Operation
+from .reader_util import clone_and_reshape_tensor
+from .reader_util import decode_str
+from .reader_util import fixup_tensors
+from .tensor import QuantizationParameters
+from .tensor import Tensor
+from .tflite_mapping import DataType
+from .tosa.TosaGraph import TosaGraph as TG
+from .tosa_mapping import datatype_map
+from .tosa_mapping import tosa_operator_map
+from .tosa_mapping import unsupported_tosa_operators
+
+
+class TosaSubgraph:
+    def __init__(self, file_path, graph, block):
+        self.graph = graph
+        self.name = decode_str(block.Name())
+
+        self.tensors = []
+        for idx in range(block.TensorsLength()):
+            self.tensors.append(self.parse_tensor(block.Tensors(idx), file_path))
+
+        for idx in range(block.OperatorsLength()):
+            self.parse_operator(idx, block.Operators(idx))
+
+        # Get the subgraph inputs and outputs
+        self.inputs = self.get_sg_inputs_remove_duplicates(block)
+        self.outputs = self.get_sg_outputs_remove_duplicates(block)
+        fixup_tensors(self.inputs, self.tensors)
+
+    def get_sg_inputs_remove_duplicates(self, block):
+        inputs = []
+        for idx in range(block.InputsLength()):
+            tens_data = block.Inputs(idx)
+            self.add_not_duplicate(tens_data, inputs, "input")
+        return inputs
+
+    def get_sg_outputs_remove_duplicates(self, block):
+        outputs = []
+        for idx in range(block.OutputsLength()):
+            tens_data = block.Outputs(idx)
+            self.add_not_duplicate(tens_data, outputs, "output")
+        return outputs
+
+    def add_not_duplicate(self, tens_data, tensors, warning_str):
+        name = decode_str(tens_data)
+        tensor = self.get_tensor_by_name(name)
+        if tensor not in tensors:
+            tensors.append(tensor)
+        else:
+            print(f"Warning: Subgraph {warning_str} tensor ({tensor}) already seen. Removing the duplicate.")
+
+    def get_tensor_by_name(self, name):
+        for tens in self.tensors:
+            if tens.name == name:
+                return tens
+        return None
+
+    def parse_operator(self, op_index, op_data):
+        op_code = op_data.Op()
+        if op_code in unsupported_tosa_operators:
+            print("Unsupported Operator", op_code)
+            assert False
+
+        op_type, attr_serializer, quant_serializer, indices = tosa_operator_map[op_code]
+        inputs = []
+        outputs = []
+        for idx in range(op_data.InputsLength()):
+            input_tens = self.get_tensor_by_name(decode_str(op_data.Inputs(idx)))
+            inputs.append(input_tens)
+            assert input_tens is not None
+
+        for idx in range(op_data.OutputsLength()):
+            output_tens = self.get_tensor_by_name(decode_str(op_data.Outputs(idx)))
+            outputs.append(output_tens)
+            assert output_tens is not None
+
+        name = "unknown_op_name"
+        if len(outputs):
+            name = outputs[0].name
+        op = Operation(op_type, name)
+        op.type.info.indices = indices
+        op.op_index = op_index
+        op.inputs = inputs
+        op.outputs = outputs
+
+        for out in op.outputs:
+            out.ops = [op]
+
+        # TODO Transpose_conv and conv3d
+        if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
+            if inputs[1].values is not None:
+                if op.type == Op.FullyConnected:
+                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0), False)
+                elif op.type.is_conv2d_op():
+                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False)
+                elif op.type.is_depthwise_conv2d_op():
+                    inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False)
+            if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
+                # No Bias tensor
+                inputs.append(None)
+            if inputs[-1] and inputs[-1].values is not None:
+                # Since bias tensor is used for both bias and scale,
+                # a clone with a unique equivalence_id is needed
+                inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,), True)
+
+        if attr_serializer is not None:
+            op.attrs = attr_serializer.deserialize(op_data)
+
+            if "dilation" in op.attrs:
+                dilation = op.attrs["dilation"]
+                if len(dilation) == 2:
+                    op.attrs["dilation"] = (1, dilation[0], dilation[1], 1)
+                elif len(dilation) == 3:
+                    # TODO CONV3D more to be done....
+                    op.attrs["dilation"] = (dilation[0], dilation[1], dilation[2], 1)
+            if "kernel" in op.attrs:
+                kernel = op.attrs["kernel"]
+                if len(kernel) == 2:
+                    op.attrs["ksize"] = (1, kernel[0], kernel[1], 1)
+                else:
+                    # TODO CONV3D more to be done....
+                    print("Unsupported kernel dimensions: ", len(kernel))
+                    assert False
+
+        if quant_serializer is not None:
+            quant_info = quant_serializer.deserialize(op_data)
+
+            # TODO tensor zero points currently set here
+            # zero points part of Rescale operation, handled in tosa_graph_optimizer
+            if "input_zp" in quant_info:
+                self.set_tensor_zp(op.ifm, quant_info["input_zp"])
+            if "weight_zp" in quant_info:
+                self.set_tensor_zp(op.weights, quant_info["weight_zp"])
+            if "ouput_zp" in quant_info:
+                self.set_tensor_zp(op.ofm, quant_info["output_zp"])
+            if "a_zp" in quant_info:
+                self.set_tensor_zp(op.ifm, quant_info["a_zp"])
+            if "b_zp" in quant_info:
+                self.set_tensor_zp(op.ifm2, quant_info["b_zp"])
+
+    def parse_tensor(self, tens_data, file_path):
+        name = decode_str(tens_data.Name())
+        np_shape = tens_data.ShapeAsNumpy()
+        shape = list(np_shape) if type(np_shape) is np.ndarray else []
+        tens_dtype = tens_data.Type()
+        dtype = datatype_map[tens_dtype]
+
+        tens = Tensor(shape, dtype, name)
+
+        # Initialize quantization parameters
+        tens.quantization = QuantizationParameters()
+
+        tens.quantization.scale_f32 = 1.0
+        if dtype == DataType.uint8:
+            tens.quantization.quant_min = 0
+            tens.quantization.quant_max = (1 << dtype.bits) - 1
+        elif dtype in (DataType.int8, DataType.int16, DataType.int32, DataType.int64):
+            tens.quantization.quant_min = -(1 << (dtype.bits - 1))
+            tens.quantization.quant_max = (1 << (dtype.bits - 1)) - 1
+
+        tens.values = None
+        if tens_data.NpyFilename() is not None:
+            try:
+                fname = decode_str(tens_data.NpyFilename())
+                tens.values = np.load(os.path.join(file_path, fname))
+                assert list(tens.values.shape) == tens.shape
+                tens.quant_values = tens.values
+            except (struct.error, TypeError, RuntimeError) as e:
+                print(f'Error: Invalid npy file. Got "{e}" ')
+                sys.exit(1)
+
+        return tens
+
+    def set_tensor_zp(self, tens, zp):
+        if tens.quantization.zero_point is None:
+            tens.quantization.zero_point = zp
+        elif tens.quantization.zero_point != zp:
+            print(f"Error: Setting tensor zp not possible, tensor already has different zero point")
+            assert False
+
+
+class TosaGraph:
+    def __init__(self, filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
+
+        self.op_times = {}
+        if batch_size is None:
+            batch_size = 1
+        self.batch_size = batch_size
+        self.name = os.path.splitext(os.path.basename(filename))[0]
+        self.initialisation_nodes = initialisation_nodes
+
+        with open(filename, "rb") as f:
+            buf = bytearray(f.read())
+
+        try:
+            parsing_step = "parsing root"
+            tosa_graph = TG.GetRootAsTosaGraph(buf, 0)
+
+            parsing_step = "parsing version"
+            self.check_version(tosa_graph)
+
+            parsing_step = "parsing blocks length"
+            file_path = os.path.dirname(filename)
+            self.subgraphs = []
+            for b_idx in range(tosa_graph.BlocksLength()):
+                parsing_step = f"parsing block {b_idx}"
+                self.subgraphs.append(TosaSubgraph(file_path, self, tosa_graph.Blocks(b_idx)))
+
+            self.nng = Graph(self.name, self.batch_size)
+            for tosa_sg in self.subgraphs:
+                sg = Subgraph(tosa_sg.name)
+                sg.original_inputs = tosa_sg.inputs  # Preserve the original input order
+                sg.output_tensors = tosa_sg.outputs
+                self.nng.subgraphs.append(sg)
+
+        except (struct.error, TypeError, RuntimeError) as e:
+            print(f'Error: Invalid .tosa file. Got "{e}" while {parsing_step}.')
+            sys.exit(1)
+
+    def check_version(self, tosa_graph):
+        version = tosa_graph.Version()
+        version_str = f"{version._major()}.{version._minor()}.{version._patch()}"
+        if version_str != "0.22.0":
+            print(f"Unsupported TOSA version: {version_str}")
+            assert False
+
+
+def read_tosa(filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
+    tosa_graph = TosaGraph(filename, batch_size, feed_dict, output_node_names, initialisation_nodes)
+    nng = tosa_graph.nng
+    nng.refresh_after_modification()
+    return nng
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
new file mode 100644
index 0000000..c87d653
--- /dev/null
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -0,0 +1,85 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# The TosaSupportedOperators class which is a collection of all supported operators and parameter checks.
+from collections import defaultdict
+
+from .data_type import DataType
+from .operation import Op
+from .supported_operators_util import docstring_format_args
+from .supported_operators_util import list_formatter
+from .tosa_mapping import optype_to_tosa_op_type
+
+
+class TosaSupportedOperators:
+    # TODO currently sparsely populated
+    # Categorised lists of supported operators
+    convolution_ops = set((Op.Conv2DBias,))
+    convolution_like_ops = convolution_ops
+    mac_main_ops = convolution_like_ops
+
+    type_conversion_ops = set((Op.Rescale,))
+    relu_ops = set((Op.Clip, Op.ReluN,))
+    activation_ops = relu_ops
+
+    npu_post_ops = activation_ops
+    supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops
+
+    # Supported data types
+    # TODO will differ compared to TensorFlow Lite, currently set to the same
+    supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
+
+    def __init__(self):
+        # Setup the generic constraints. Note: the order matters
+        self.generic_constraints = []
+        self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
+
+        # Setup specific constraints. Note: the order matters
+        self.specific_constraints = defaultdict(list)
+
+    def is_operator_supported(self, op):
+        ext_type = optype_to_tosa_op_type(op.type)
+        if op.type not in TosaSupportedOperators.supported_operators:
+            if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
+                print(f"Info: {ext_type} '{op.name}' is not a NPU op")
+            return False
+
+        for constraint in self.generic_constraints + self.specific_constraints[op.type]:
+            valid, extra = constraint(op)
+            if not valid:
+                print(f"Warning: {ext_type} '{op.name}' is not supported on the NPU")
+                print(f" - {constraint.__doc__}")
+                if extra:
+                    print(f"   {extra}")
+                return False
+
+        return True
+
+    # TODO this function is the same for TensorFlow Lite, but input might differ
+    @classmethod
+    @docstring_format_args([list_formatter(supported_op_dtypes)])
+    def constraint_tens_dtype(cls, op):
+        "Tensors must be of type: {}"
+        valid = True
+        extra = []
+        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+        if not tensors:
+            tensors = [tens for tens in op.inputs if tens]
+        for tens in tensors:
+            if tens.dtype not in cls.supported_op_dtypes:
+                valid = False
+                extra.append(f"Tensor '{tens.name}' has data type: {tens.dtype}")
+        return valid, ", ".join(extra)
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index f552b21..ecdc7aa 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -54,7 +54,7 @@
     output_basename = os.path.join(compiler_options.output_dir, os.path.splitext(os.path.basename(input_name))[0])
     DebugDatabase.show_warnings = enable_debug_db
 
-    nng = model_reader.read_model(input_name, model_reader_options)
+    nng, network_type = model_reader.read_model(input_name, model_reader_options)
 
     if not nng:
         raise InputFileError(input_name, "Input file could not be read")
@@ -67,7 +67,7 @@
         print("Model reading took %f s" % (stop - start))
         start = time.time()
 
-    compiler_driver.compiler_driver(nng, arch, compiler_options, scheduler_options)
+    compiler_driver.compiler_driver(nng, arch, compiler_options, scheduler_options, network_type)
 
     summary_csv_file = "{0}_summary_{1}.csv".format(output_basename, arch.system_config)
     stats_writer.write_summary_metrics_csv(nng, summary_csv_file, arch)
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 4ba3dee..7e33e93 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -203,7 +203,7 @@
     return ohwi[core : ohwi.shape[0] : ncores]
 
 
-def _prepare_scale_and_bias(arch, tens, rescale_for_faf):
+def _prepare_scale_and_bias(arch, tens, rescale_for_faf, explicit_scaling):
     assert tens.purpose in [TensorPurpose.FeatureMap, TensorPurpose.FSBias]
     assert tens.format == TensorFormat.NHWC
     # the connected operator should expect a bias input unless it is a FullyConnected
@@ -260,11 +260,15 @@
         else:
             raise UnsupportedFeatureError(f"Compression of {ifm_dtype} is not implemented; Tensor: '{tens.name}'")
 
-    # quantise all of the weight scales into (scale_factor, shift)
-    if ifm_dtype == DataType.int16:
-        quantised_scales = [reduced_quantise_scale(scale) for scale in scales]
+    if explicit_scaling:
+        assert len(explicit_scaling.shift) == len(explicit_scaling.multiplier)
+        quantised_scales = [(int(m), int(s)) for s, m in zip(explicit_scaling.shift, explicit_scaling.multiplier)]
     else:
-        quantised_scales = [quantise_scale(scale) for scale in scales]
+        # quantise all of the weight scales into (scale_factor, shift)
+        if ifm_dtype == DataType.int16:
+            quantised_scales = [reduced_quantise_scale(scale) for scale in scales]
+        else:
+            quantised_scales = [quantise_scale(scale) for scale in scales]
 
     # If only 1 quantised scale is used, repeat that value for the length of the biases
     if len(quantised_scales) == 1:
@@ -355,7 +359,7 @@
 
     # Bias & scale
     if do_scales:
-        quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf)
+        quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf, op.explicit_scaling)
         scale_tens.element_size_bytes = 10
 
     # Slice the weight stream up depth-ways into bricks and compress