MLBEDSW-3772 Reshape removal

-Removed reshapes in the original graph
-Removed the addition of reshapes to the optimized graph

-Reshapes with different ifm/ofm quantisation will remain

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I94862be53dac0d7434815e2aee5ca678228495f8
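
The mechanics, in brief: instead of materialising Reshape ops (and the
reshape tensors they produce), the pre- and post-reshape 4-D shapes are
carried on the neighbouring operations via ifm_shapes/ofm_shapes
(Shape4D). The recurring expression axis_4D = axis + (4 - len(shape))
maps an axis of an N-D shape to its position once the shape has been
left-padded with ones to 4-D; e.g. axis 1 of [2, 3, 4] becomes axis 2
of [1, 2, 3, 4]. A minimal sketch of the bypass itself, using
simplified hypothetical stand-in classes rather than vela's real
Tensor/Operation types:

    # Illustration only: hypothetical, simplified graph classes.
    class Tensor:
        def __init__(self, shape):
            self.shape = shape
            self.ops = []            # producing operations
            self.consumer_list = []  # consuming operations

    class Operation:
        def __init__(self, op_type, ifm, ofm):
            self.type = op_type
            self.inputs = [ifm]
            self.outputs = [ofm]
            ifm.consumer_list.append(self)
            ofm.ops = [self]

    def bypass_reshape(reshape_op):
        # Rewire every consumer of the reshape's ofm to read the ifm
        # directly, then detach the reshape from the ifm.
        ifm, ofm = reshape_op.inputs[0], reshape_op.outputs[0]
        for cons in ofm.consumer_list:
            cons.inputs = [ifm if t is ofm else t for t in cons.inputs]
            ifm.consumer_list.append(cons)
        ifm.consumer_list.remove(reshape_op)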
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index 78d7f12..3d4f758 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -146,9 +146,6 @@
     if options.verbose_quantization:
         nng.print_graph_with_tensor_quantization()
 
-    nng = graph_optimiser.optimise_graph_b(nng, arch, options.verbose_graph)
-    assert verify_graph_health(nng)
-
     nng = mark_tensors.mark_tensor_purpose(nng, arch, options.verbose_tensor_purpose)
     assert verify_graph_health(nng)
     nng = insert_dma.insert_dma_commands(nng, arch, options.verbose_graph)
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
index 006348c..adf0362 100644
--- a/ethosu/vela/debug_database.py
+++ b/ethosu/vela/debug_database.py
@@ -25,6 +25,7 @@
 
 from . import numeric_util
 from .operation import Operation
+from .shape4d import Shape4D
 
 
 class DebugDatabase:
@@ -77,7 +78,10 @@
                 src_uid = cls._sourceUID[parent]
             uid = len(cls._optimisedUID)
             cls._optimisedUID[op] = (uid, src_uid)
-            ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1)
+            if len(op.ofm_shapes) == 0:
+                ofm_shape = Shape4D(op.outputs[0].shape)
+            else:
+                ofm_shape = op.ofm_shapes[0]
             cls._optimisedTable.append(
                 [
                     uid,
@@ -85,9 +89,9 @@
                     str(op.type),
                     op.kernel.width,
                     op.kernel.height,
-                    ofm_shape[-2],
-                    ofm_shape[-3],
-                    ofm_shape[-1],
+                    ofm_shape.width,
+                    ofm_shape.height,
+                    ofm_shape.depth,
                 ]
             )
 
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 5f11178..bb5a9e0 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -28,6 +28,7 @@
 from .data_type import DataType
 from .debug_database import DebugDatabase
 from .errors import UnsupportedFeatureError
+from .errors import VelaError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .numeric_util import clamp_sigmoid
 from .numeric_util import full_shape
@@ -42,7 +43,6 @@
 from .softmax import SoftMax
 from .tensor import check_quantized_tens_scaling_equal
 from .tensor import create_const_tensor
-from .tensor import create_reshape_tensor
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 from .tflite_mapping import optype_to_builtintype
@@ -59,52 +59,68 @@
     return tens
 
 
-def rewrite_concat(tens, arch, nng):
-    if len(tens.ops) == 1 and tens.ops[0].type.is_concat_op():
-        concat_op = tens.ops[0]
-        if tens != concat_op.outputs[0]:
-            return tens  # don't attempt to rewrite the min/max outputs of QuantizedConcat
+def rewrite_concat_ops(op, arch, nng):
+    if not op.run_on_npu or not op.type.is_concat_op():
+        return op
 
-        # Not supported so leave it and run on CPU
-        if not concat_op.run_on_npu:
-            return tens
+    axis_4D = 0
+    ofm = op.ofm
+    ofm.ops = []
+    offset = 0
 
-        inputs, axis = concat_op.get_concat_inputs_axis()
+    if op.type == Op.Pack:
+        # Pack is also referred to as Stack
+        axis = int(op.attrs["axis"])
+        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
 
-        tens.ops = []
-        offset = 0
-        for idx, inp in enumerate(inputs):
+        if axis >= 0:
+            axis_4D = axis + (4 - len(desired_shape))
+        else:
+            axis_4D = axis
+
+        for idx, inp in enumerate(op.inputs):
+            op.ifm_shapes[idx] = Shape4D(desired_shape)
+            if Shape4D(inp.shape) != op.ifm_shapes[idx]:
+                inp.avoid_NHCWB16 = True
+        op.type = Op.PackReshaped
+
+    inputs, axis = op.get_concat_inputs_axis()
+
+    for idx, inp in enumerate(inputs):
+        if op.type != Op.PackReshaped:
+            op.ifm_shapes[idx] = Shape4D(inp.shape)
             if axis >= 0:
                 axis_4D = axis + (4 - len(inp.shape))
             else:
                 axis_4D = axis
-            new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
-            new_op.inputs = [inp]
-            new_op.outputs = [tens]
-            new_op.attrs["concat_axis"] = axis_4D
-            new_op.attrs["concat_start"] = offset
-            offset += inp.shape[axis]
-            new_op.attrs["concat_end"] = offset
-            new_op.run_on_npu = True
-            tens.ops.append(new_op)
-            DebugDatabase.add_optimised(concat_op, new_op)
-            new_op.set_ifm_ofm_shapes()
-        assert tens.shape[axis] == offset
+        new_op = Operation(Op.ConcatSliceWrite, op.name + str(idx))
+        new_op.inputs = [inp]
+        new_op.outputs = [ofm]
+        new_op.attrs["concat_axis"] = axis_4D
+        new_op.attrs["concat_start"] = offset
+        offset += op.ifm_shapes[idx].get_dim(axis_4D)
 
-        # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
-        # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
-        # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
-        # and those addresses are always 16 byte aligned due to the NHCWB16 format.
-        if axis == -1 or axis == (len(tens.shape) - 1):
-            for op in tens.ops:
-                if op.attrs["concat_start"] % 16 != 0:
-                    tens.avoid_NHCWB16 = True
-                    break
+        new_op.attrs["concat_end"] = offset
+        new_op.run_on_npu = True
+        ofm.ops.append(new_op)
+        DebugDatabase.add_optimised(op, new_op)
+        new_op.ifm_shapes.append(op.ifm_shapes[idx].clone())
+        new_op.ofm_shapes.append(op.ofm_shapes[0].clone())
+    assert ofm.shape[axis] == offset
 
-    return tens
+    # If axis corresponds to the C-dimension, NHCWB16 can only be used in the output if all the concat_start values
+    # are a multiple of 16. Only then will the address offset for the ofm, for all operations, be 16 byte aligned.
+    # For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0 and those
+    # addresses are always 16 byte aligned due to the NHCWB16 format.
+    if axis == -1 or axis == (len(ofm.shape) - 1):
+        for op in ofm.ops:
+            if op.attrs["concat_start"] % 16 != 0:
+                ofm.avoid_NHCWB16 = True
+                break
+    return op
 
 
-def rewrite_split(tens, arch, nng):
+def rewrite_split_ops(tens, arch, nng):
 
     if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
         split_op = tens.ops[0]
@@ -118,20 +134,27 @@
         tens.ops = []
         new_op = Operation(Op.SplitSliceRead, split_op.name)
         new_op.inputs = [inp]
+        ofm_shape_idx = 0
 
         # For Split the offset cannot be extracted from the tensor so it has to
         # be calculated from the index of the output tensor
         if axis is not None:
             # Get the start and end of the split
             offset_start = [0] * 4
+            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
             for idx, out in enumerate(outputs):
-                split_op.ofm_shapes[idx] = Shape4D(out.shape)
-                if out == tens:
-                    break
-                if axis >= 0:
-                    axis_4D = axis + (4 - len(out.shape))
+                if axis_4D_list is not None:
+                    axis_4D = axis_4D_list[idx]
                 else:
-                    axis_4D = axis
+                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
+                    if axis >= 0:
+                        axis_4D = axis + (4 - len(out.shape))
+                    else:
+                        axis_4D = axis
+
+                if out == tens:
+                    ofm_shape_idx = idx
+                    break
 
                 offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)
 
@@ -145,7 +168,7 @@
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
         new_op.ifm_shapes.append(Shape4D(inp.shape))
-        new_op.ofm_shapes.append(Shape4D(full_shape(4, tens.shape, 1)))
+        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx].clone())
         DebugDatabase.add_optimised(split_op, new_op)
 
     return tens
@@ -158,9 +181,9 @@
     return total_padding
 
 
-def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims, explicit_padding):
-    ypad = needed_total_padding(int(input_dims[1]), int(stride[1]), int(kernel_size[0]))
-    xpad = needed_total_padding(int(input_dims[2]), int(stride[2]), int(kernel_size[1]))
+def calc_padding_and_skirt(padding_type, kernel_size, stride, input_shape, explicit_padding):
+    ypad = needed_total_padding(int(input_shape.height), int(stride[1]), int(kernel_size[0]))
+    xpad = needed_total_padding(int(input_shape.width), int(stride[2]), int(kernel_size[1]))
     if padding_type == Padding.SAME:
         left_pad = (xpad + 0) // 2
         right_pad = (xpad + 1) // 2
@@ -184,11 +207,11 @@
     return padding, skirt
 
 
-def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dims, upscaling_factor):
+def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
     kernel_height, kernel_width = kernel_size[0], kernel_size[1]
     if padding_type == Padding.SAME:
-        ypad = needed_total_padding(int(input_dims[1]) * upscaling_factor, int(stride[1]), int(kernel_height))
-        xpad = needed_total_padding(int(input_dims[2]) * upscaling_factor, int(stride[2]), int(kernel_width))
+        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
+        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
         right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
         bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
         left_pad = max(kernel_width - 1 - right_pad, 0)
@@ -225,7 +248,7 @@
     op.name = op.name + "_add"
     op.attrs["resizebilinear"] = True
     # Create an input tensor filled with zeros
-    shape = op.outputs[0].shape
+    shape = op.ofm_shapes[0].as_list()
     tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
     tens.values = np.zeros(shape)
     tens.quant_values = np.zeros(shape, np.uint8)
@@ -258,8 +281,8 @@
         op.attrs["padding"] = Padding.SAME
     op.inputs[0].resampling_mode = resampling_mode.NEAREST
 
-    upscaled_shape = np.array(op.inputs[0].shape[1:3])
-    out_shape = np.array(op.outputs[0].shape[1:3])
+    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
+    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
     if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
         return op
 
@@ -276,8 +299,8 @@
             scaled_op.outputs = outputs
             scaled_op.outputs[0].ops = [scaled_op]
         else:
-            shape = outputs[0].shape.copy()
-            shape[1:3] = upscaled_shape[0:2]
+            shape = op.ofm_shapes[0].as_list()
+            shape[1:3] = upscaled_shape
             out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
             out_tens.quantization = op.outputs[0].quantization.clone()
             out_tens.quantization.quant_min = np.iinfo(np.int16).min
@@ -300,11 +323,11 @@
 
 def fixup_resizebilinear(op, arch, nng):
     if op.type == Op.ResizeBilinear and op.run_on_npu:
-        if op.inputs[0].shape == op.outputs[0].shape:
+        if op.ifm_shapes[0] == op.ofm_shapes[0]:
             # Bypass nop resizebilinear
             op.inputs = op.inputs[:1]
             op.type = Op.Identity
-        elif op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
+        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
             convert_resizebilinear_1x1_to_add(op)
         else:
             convert_resizebilinear_to_2x2_pool(op)
@@ -321,109 +344,26 @@
     return op
 
 
-def fixup_fully_connected_input(op, arch, nng):
-    if op.type == Op.FullyConnected:
-        inp = op.inputs[0]
-        weights = op.inputs[1]
-
-        n_in_elems = weights.shape[-2]
-        elms = inp.elements()
-        batch_size = elms // n_in_elems
-        assert batch_size * n_in_elems == elms
-
-        desired_shape = [batch_size, n_in_elems]
-        if inp.shape != desired_shape:
-            # mismatch, insert a reshape to fix this.
-            op.set_input_tensor(create_reshape_tensor(inp, desired_shape), 0)
-
-    return op
-
-
 def convert_batched_fc_shape(op, arch, nng):
     if op.type == Op.FullyConnected:
-        ifm = op.inputs[0]
-        ofm = op.outputs[0]
-        # Check if the FC is 2D and first dimension indicates batching
-        # TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete
-        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1:
-            n = ifm.shape[0]
+        # Check if the first dimension indicates batching
+        if op.ifm_shapes[0].batch > 1:
             batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+            n = op.ifm_shapes[0].batch
             h, w = batching_split.get(n, (1, n))
+            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
 
-            prev_op = ifm.ops[0]
-            desired_shape = [1, h, w, ifm.shape[-1]]
-            op.ifm_shapes[0] = Shape4D(desired_shape)
-
-            if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
-                # There is a preceding Reshape
-                # Compare input of prev_op and input of op, to see if prev_op can be removed
-                ifm_prev_op = prev_op.inputs[0]
-                if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm):
-                    # prev_op can be removed
-                    op.set_input_tensor(ifm_prev_op, 0)
-                else:
-                    op.inputs[0].set_all_shapes(desired_shape)
-                    prev_op.set_input_tensor(
-                        create_const_tensor(prev_op.inputs[1].name, [1], DataType.int32, desired_shape), 1
-                    )
-                    prev_op.attrs["new_shape"] = desired_shape
-            else:
-                # Add reshape op to the input if there is no preceding reshape
-                ifm.consumer_list.remove(op)
-                op.set_input_tensor(create_reshape_tensor(ifm, desired_shape), 0)
+            op.ifm.avoid_NHCWB16 = True
 
             # Reshape Weights to be 4D. IO becomes HWIO
             weight_tensor = op.inputs[1]
             weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
             weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
 
-            desired_shape = [1, h, w, ofm.shape[-1]]
-            op.ofm_shapes[0] = Shape4D(desired_shape)
-
-            if (
-                len(ofm.consumer_list) == 1
-                and ofm.consumer_list[0] is not None
-                and ofm.consumer_list[0].type == Op.Reshape
-            ):
-                # There is a subsequent Reshape
-                # Compare desired shape and output of consumer op, to see if consumer op can be removed
-                ofm_cons_op = ofm.consumer_list[0].outputs[0]
-                if desired_shape == ofm_cons_op.shape and check_quantized_tens_scaling_equal(ofm, ofm_cons_op):
-                    op.outputs[0] = ofm_cons_op
-                    op.outputs[0].ops = [op]
-                else:
-                    op.outputs[0].set_all_shapes(desired_shape)
-            else:
-                # Add reshape op to the output
-                op.set_output_tensor(create_reshape_tensor(ofm, desired_shape, False))
-    return op
-
-
-def fixup_pack_input(op, arch, nng):
-    if op.type == Op.Pack:
-        # Pack is also referred to as Stack
-        # Requires the rewrite_concat function to be called on the op afterwards
-        axis = int(op.attrs["axis"])
-        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
-
-        # Construct 1 shape tensor to be used by all inserted reshape ops
-        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, desired_shape)
-
-        for idx, inp in enumerate(op.inputs):
-            reshape_out = inp.clone("_reshaped")
-            reshape_out.set_all_shapes(desired_shape)
-
-            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
-            reshape_op.attrs["new_shape"] = desired_shape
-            reshape_op.inputs = [inp, new_shape_tens]
-            reshape_op.set_output_tensor(reshape_out)
-            reshape_op.set_ifm_ofm_shapes()
-            DebugDatabase.add_optimised(op, reshape_op)
-
-            op.inputs[idx] = reshape_out
-
-        op.type = Op.PackReshaped
-
+            n = op.ofm_shapes[0].batch
+            h, w = batching_split.get(n, (1, n))
+            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
+            op.ofm.avoid_NHCWB16 = True
     return op
 
 
@@ -441,12 +381,19 @@
     return op
 
 
-def fixup_stridedslice_output(tens, arch, nng):
-    op = tens.ops[0]
-    if op.run_on_npu and op.type == Op.StridedSlice:
-        reshape_input_shape = tens.shape
-        new_axis_mask = op.attrs["new_axis_mask"]
-        shrink_axis_mask = op.attrs["shrink_axis_mask"]
+def rewrite_stridedslice_output(op, arch, nng):
+    if not op.run_on_npu or op.type != Op.StridedSlice:
+        return op
+
+    new_axis_mask = op.attrs["new_axis_mask"]
+    shrink_axis_mask = op.attrs["shrink_axis_mask"]
+
+    if shrink_axis_mask == 0 and new_axis_mask == 0:
+        return op
+
+    axis_4D = [0] * len(op.outputs)
+    for idx, out_tens in enumerate(op.outputs):
+        output_shape = list(out_tens.shape)
 
         if shrink_axis_mask != 0:
             n = 0
@@ -456,10 +403,16 @@
                 n += 1
                 shrink_axis_mask &= shrink_axis_mask - 1
                 axis = int(math.log2(prev_mask - shrink_axis_mask))
-                reshape_input_shape = reshape_input_shape[:axis] + [1] + reshape_input_shape[axis:]
+                output_shape = output_shape[:axis] + [1] + output_shape[axis:]
 
-            assert len(tens.shape) == (len(op.inputs[0].shape) - n)
+            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
             op.attrs["shrink_axis_mask"] = 0
+            if axis >= 0:
+                axis_4D[idx] = axis + (4 - len(output_shape))
+            else:
+                axis_4D[idx] = axis
+            op.ofm_shapes[idx] = Shape4D(output_shape)
+
         elif new_axis_mask != 0:
             n = 0
             axis = 0
@@ -468,77 +421,62 @@
                 n += 1
                 new_axis_mask &= new_axis_mask - 1
                 axis = int(math.log2(prev_mask - new_axis_mask))
-                reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :]
+                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                 new_axis_mask >>= 1
 
-            assert len(tens.shape) == (len(op.inputs[0].shape) + n)
+            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
             op.attrs["new_axis_mask"] = 0
-        else:
-            # Equal Rank StridedSlice, no need to insert reshape
-            return tens
+            if axis >= 0:
+                axis_4D[idx] = axis + (4 - len(output_shape))
+            else:
+                axis_4D[idx] = axis
+            op.ofm_shapes[idx] = Shape4D(output_shape)
 
-        # Construct 1 shape tensor to be used by all inserted reshape ops
-        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
+        if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
+            out_tens.avoid_NHCWB16 = True
 
-        for idx, out_tens in enumerate(op.outputs):
-            op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape)
-            reshape_in = out_tens.clone("_reshaped")
-            reshape_in.set_all_shapes(reshape_input_shape)
-            reshape_in.ops = [op]
-
-            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
-            reshape_op.attrs["new_shape"] = reshape_input_shape
-            reshape_op.inputs = [reshape_in, new_shape_tens]
-            reshape_op.set_output_tensor(out_tens)
-            reshape_op.set_ifm_ofm_shapes()
-
-            op.outputs[idx] = reshape_in
-
-    return tens
+    op.attrs["split_axis_4D"] = axis_4D
+    return op
 
 
-def fixup_unpack_output(tens, arch, nng):
-    op = tens.ops[0]
+def rewrite_unpack_output(op, arch, nng):
+    tens = op.outputs[0]
     if op.run_on_npu and op.type == Op.Unpack:
         # Unpack is also referred to as Unstack
-        # Requires the rewrite_split function to be called on the op afterwards
         axis = int(op.attrs["axis"])
         op.type = Op.UnpackReshaped
-        reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
+        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
 
-        # Construct 1 shape tensor to be used by all inserted reshape ops
-        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
+        if axis >= 0:
+            axis_4D = axis + (4 - len(desired_output_shape))
+        else:
+            axis_4D = axis
 
+        axis_4D_list = [0] * len(op.outputs)
         for idx, out_tens in enumerate(op.outputs):
-            reshape_in = out_tens.clone("_reshaped")
-            reshape_in.set_all_shapes(reshape_input_shape)
-            reshape_in.ops = [op]
+            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
+            axis_4D_list[idx] = axis_4D
+            if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
+                out_tens.avoid_NHCWB16 = True
 
-            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
-            reshape_op.attrs["new_shape"] = reshape_input_shape
-            reshape_op.inputs = [reshape_in, new_shape_tens]
-            reshape_op.set_output_tensor(out_tens)
-            reshape_op.set_ifm_ofm_shapes()
-            DebugDatabase.add_optimised(op, reshape_op)
-
-            op.outputs[idx] = reshape_in
-    return tens
+        op.attrs["split_axis_4D"] = axis_4D_list
+    return op
 
 
 def add_padding_fields(op, arch, nng):
     if op.run_on_npu:
         if "padding" in op.attrs:
+            input_shape = op.ifm_shapes[0]
+            output_shape = op.ofm_shapes[0]
             if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                 kernel_size = op.inputs[1].shape[:2]
-                input_shape = op.inputs[0].shape
             elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                 kernel_size = op.attrs["ksize"][1:3]
-                input_shape = op.inputs[0].shape
             else:
                 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
 
             if op.type == Op.Conv2DBackpropInputSwitchedBias:
-                upscaling_factor = op.outputs[0].shape[1] // input_shape[1]
+                upscaling_factor = output_shape.height // input_shape.height
                 padding, skirt = calc_upscaled_padding_and_skirt(
                     op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                 )
@@ -582,10 +520,10 @@
     # switch of the operator type (and weight order)
 
     if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
-        ifm_tensor = op.inputs[0]
+        ifm_shape = op.ifm_shapes[0]
         weight_tensor = op.inputs[1]
-        ofm_tensor = op.outputs[0]
-        if (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"]):
+        ofm_shape = op.ofm_shapes[0]
+        if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
             # Change op type to Conv2d
             op.type = Op.Conv2DBias
             del op.attrs["channel_multiplier"]
@@ -596,7 +534,7 @@
         else:
             raise UnsupportedFeatureError(
                 f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
-                f" ifm channels = {ifm_tensor.shape[3]}, ofm channels = {ofm_tensor.shape[3]}",
+                f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
             )
         DebugDatabase.add_optimised(op, op)
     return op
@@ -620,17 +558,15 @@
         op.type == Op.Conv2DBias
         and op.op_index == 0
         and stride_x == 2
-        and len(ifm_tensor.shape) == 4
-        and ifm_tensor.shape[3] <= 4
-        and ifm_tensor.shape[2] % 2 == 0
+        and op.ifm_shapes[0].depth <= 4
+        and op.ifm_shapes[0].width % 2 == 0
         and weight_tensor is not None
         and weight_tensor.shape[1] >= 2
     ):
+        ifm_shape = op.ifm_shapes[0]
         # IFM
-        ifm_reshaped = create_reshape_tensor(
-            ifm_tensor, [ifm_tensor.shape[0], ifm_tensor.shape[1], ifm_tensor.shape[2] // 2, ifm_tensor.shape[3] * 2]
-        )
-        op.set_input_tensor(ifm_reshaped, 0)
+        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
+        op.ifm.avoid_NHCWB16 = True
 
         # Weights
         weight_shape = weight_tensor.shape
@@ -657,8 +593,6 @@
         stride_x = 1
         op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
 
-        op.set_ifm_ofm_shapes()
-
     return op
 
 
@@ -683,27 +617,6 @@
             weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
             weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
 
-            # The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it
-            # back to 4D afterwards as the next layer is expecting that shape
-            orig_ofm_tensor = op.outputs[0]
-            # Reshape this ops output to be 2D: {(N*H*W), C} (We know N H and W are all 1 so this becomes {1, C})
-            fc_ofm_tensor = orig_ofm_tensor.clone("_fc")
-            fc_ofm_tensor.set_all_shapes([1, fc_ofm_tensor.shape[-1]])
-            fc_ofm_tensor.ops = [op]
-            # Add a reshape after the new OFM to convert it back to the original 4D shape
-            reshape_name = op.name + "_reshape"
-            new_shape_tens = create_const_tensor(reshape_name + "_shape", [1], DataType.int32, orig_ofm_tensor.shape)
-            reshape_op = Operation(Op.Reshape, reshape_name)
-            reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
-            reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
-            reshape_op.set_output_tensor(orig_ofm_tensor)
-            reshape_op.set_ifm_ofm_shapes()
-
-            # Replace this ops OFM to point to the 2D tensor
-            op.outputs[0] = fc_ofm_tensor
-            op.set_ifm_ofm_shapes()
-            # Record optimisation in debug database
-            DebugDatabase.add_optimised(op, reshape_op)
             DebugDatabase.add_optimised(op, op)
     return op
 
@@ -722,14 +635,6 @@
             # Tidy up and assign the ifm and ofm to the new op
             ifm.consumer_list.remove(op)
 
-            # if not 4d, reshape ifm/ofm
-            if len(ifm.shape) < 4:
-                ifm_shaped = create_reshape_tensor(ifm, full_shape(4, ifm.shape, 1))
-                ifm = ifm_shaped
-            if len(ofm.shape) < 4:
-                ofm_shaped = create_reshape_tensor(ofm, full_shape(4, ofm.shape, 1), False)
-                ofm = ofm_shaped
-
             relu_fused_op.add_input_tensor(ifm)
             relu_fused_op.set_output_tensor(ofm)
             relu_fused_op.set_ifm_ofm_shapes()
@@ -737,6 +642,7 @@
     return op
 
 
+# TODO remove if mem only ops can all be removed
 # Reorder activation op if it's after the memory only operations
 def fixup_act_reorder(op, arch, nng):
     if op.type.is_relu_op() or op.type in (Op.Sigmoid, Op.Tanh):
@@ -752,8 +658,8 @@
             act_op_out = act_op.inputs[0].clone("_acted")
             act_op_out.quantization = op.outputs[0].quantization.clone()
             act_op.set_output_tensor(act_op_out)
-            act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape)
-            act_op.ofm_shapes[0] = Shape4D(act_op_out.shape)
+            act_op.ofm_shapes[0] = act_op.ifm_shapes[0].clone()
+            act_op.ifm_shapes[0] = prep_op.ifm_shapes[0].clone()
 
             # Update the consumer list
             act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -1078,39 +984,94 @@
     return op
 
 
-def remove_unwanted_reshapes(op, arch, nng):
-    # Try to remove reshapes enclosing ElementWise operator with only one non-constant input
-    if not op.run_on_npu or not op.type.is_elementwise_op():
-        return op
+def remove_reshapes(op, arch):
+    if op.run_on_npu and op.type == Op.Reshape:
+        ofm = op.ofm
+        ifm = op.ifm
 
-    # Check if the ElementWise operator only have one non-constant input
-    non_const_tens = [x for x in op.inputs if x.ops[0].type != Op.Const]
-    if len(non_const_tens) != 1:
-        return op
-    ifm = non_const_tens[0]
+        # Check if quantization is the same in the input and output for the reshape ops
+        if not check_quantized_tens_scaling_equal(ifm, ofm):
+            # TODO Both tensors are needed, since quantization properties currently are linked to Tensors.
+            # In order to remove this reshape, either quantization properties need to be moved to Operator,
+            # or the reshape needs to be replaced with a NOP.
+            return
 
-    # Check if operation is enclosed by Reshapes that can be removed
-    ofm = op.outputs[0]
-    prev_op = ifm.ops[0]
-    if (
-        len(ifm.consumer_list) == 1
-        and prev_op.type == Op.Reshape
-        and len(ofm.consumer_list) == 1
-        and ofm.consumer_list[0].type == Op.Reshape
-    ):
-        # Operation is enclosed by reshapes, check if they can be removed
-        prev_op_ifm, prev_op_ofm = prev_op.get_ifm_ofm()
-        cons_op = ofm.consumer_list[0]
-        cons_op_ifm = ofm
-        cons_op_ofm = cons_op.outputs[0]
-        if len(prev_op_ifm.shape) == len(cons_op_ofm.shape):
-            # Check if quantization is the same in the input and output for the reshape ops
-            if check_quantized_tens_scaling_equal(prev_op_ifm, prev_op_ofm) and check_quantized_tens_scaling_equal(
-                cons_op_ifm, cons_op_ofm
-            ):
-                op.set_input_tensor(prev_op_ifm, 0)
-                op.set_output_tensor(cons_op_ofm)
-    return op
+        # Check if ifm is a sg input
+        if ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
+            # put the reshape on CPU
+            op.run_on_npu = False
+            return
+
+        # Check if Reshape ifm/ofm are network ifm/ofm
+        ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
+        ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
+
+        if ifm_is_sg_ofm and ofm_is_sg_ofm:
+            # Both ifm and ofm are sg outputs, add reshape to the ifm and put it on CPU
+            ifm_cons_list_copy = ifm.consumer_list.copy()
+            ifm_ops_copy = ifm.ops.copy()
+            for ifm_cons in ifm_cons_list_copy:
+                if ifm_cons is None:
+                    # Create a reshape op with ifm as output
+                    name = ifm.name + "_cpu_reshape"
+                    reshape_ifm = ifm.clone()
+                    reshape_op = Operation(Op.Reshape, name)
+                    reshape_op.attrs["new_shape"] = ifm.shape
+                    reshape_op.add_input_tensor(reshape_ifm)
+                    reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, ifm.shape))
+                    reshape_op.set_output_tensor(ifm)
+                    reshape_op.set_ifm_ofm_shapes()
+                    reshape_op.run_on_npu = False
+                    reshape_op.ofm.ops = [reshape_op]
+                    reshape_op.ofm.consumer_list = [None]
+
+                    # Set reshape_ifm producers
+                    for prev_op in ifm_ops_copy:
+                        prev_op.outputs = [reshape_ifm]
+                        reshape_ifm.ops.append(prev_op)
+
+                    # Set reshape_ifm consumers
+                    for ifm_cons in ifm_cons_list_copy:
+                        if ifm_cons is not None:
+                            for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+                                if cons_ifm == ifm:
+                                    ifm_cons.set_input_tensor(reshape_ifm, ifm_idx)
+
+                    ifm = reshape_ifm
+                    break
+            ifm_is_sg_ofm = False
+
+        if ofm_is_sg_ofm:
+            # Bypassed by replacing ifm with ofm
+            ofm.ops = []
+            for prev_op in ifm.ops:
+                prev_op.outputs = [ofm]
+                ofm.ops.append(prev_op)
+
+            # All ifm consumers need to use ofm as input
+            for ifm_cons in ifm.consumer_list:
+                for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+                    if cons_ifm == ifm:
+                        ifm_cons.set_input_tensor(ofm, ifm_idx)
+            if op.ifm_shapes[0] != op.ofm_shapes[0]:
+                ofm.avoid_NHCWB16 = True
+        else:
+            # Bypassed Reshape by replacing ofm with ifm
+            for cons in ofm.consumer_list:
+                for ifm_idx, cons_ifm in enumerate(cons.inputs):
+                    if cons_ifm == ofm:
+                        cons.set_input_tensor(ifm, ifm_idx)
+            if op.ifm_shapes[0] != op.ofm_shapes[0]:
+                ifm.avoid_NHCWB16 = True
+
+
+def check_reshapes(op, arch):
+    if op.run_on_npu and op.type == Op.Reshape:
+        ofm = op.ofm
+
+        if check_quantized_tens_scaling_equal(op.ifm, ofm):
+            # Reshape should have been removed
+            raise VelaError(f"Reshape op {op} expected to have been removed, still remains")
 
 
 def fuse_activation_function_with_prev(op, arch, nng):
@@ -1174,13 +1135,19 @@
 def add_attrs_to_resizebilinear(op, arch, nng):
     if op.type == Op.ResizeBilinear and op.run_on_npu:
         input_tensor = op.inputs[0]
-        upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2]
-        out_shape = op.outputs[0].shape[1:3]
-        if not op.attrs["align_corners"] and out_shape == upscaled_shape:
+        input_shape = op.ifm_shapes[0]
+        upscaled_height = input_shape.height * 2
+        upscaled_width = input_shape.width * 2
+        out_shape = op.ofm_shapes[0]
+        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
             # this means the output is supposed to be a x2 upscale,
             # so we need to do SAME padding
             op.attrs["padding"] = Padding.SAME
-        elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
+        elif (
+            op.attrs["align_corners"]
+            and out_shape.height == (upscaled_height - 1)
+            and out_shape.width == (upscaled_width - 1)
+        ):
             # here we can just run the avg pool without padding and
             # produce a (M * 2 - 1, N * 2 - 1) sized output
             op.attrs["padding"] = Padding.VALID
@@ -1229,26 +1196,52 @@
             nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
         )
 
+    # Handle Concat Ops
+    for idx, sg in enumerate(nng.subgraphs):
+        # rewrite graph pass
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [rewrite_concat_ops], rewrite_unsupported=False,
+        )
+
+    # Handle Split Ops
+    for idx, sg in enumerate(nng.subgraphs):
+        # rewrite graph pass
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng,
+            sg,
+            arch,
+            [],
+            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
+            rewrite_unsupported=False,
+        )
+
+    for idx, sg in enumerate(nng.subgraphs):
+        # rewrite graph pass
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
+        )
+
+    # Removal of reshapes
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+        sg.refresh_after_modification()
+
     op_rewrite_list = [
         set_tensor_equivalence,
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,
         optimise_strided_conv,
-        fixup_fully_connected_input,
         convert_batched_fc_shape,
-        fixup_pack_input,
         unfuse_activation_function,
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
         fixup_act_reorder,
-        fixup_elementwise_with_scalars,
+        fixup_elementwise_with_scalars,  # TODO Move to early stage?
         reorder_depthwise_weights,
         fixup_resizebilinear,
         fixup_bias_tensors,
-        convert_nop_split_to_identity,
         convert_mul_max_to_abs_or_lrelu,
-        remove_unwanted_reshapes,
         convert_lrelu,
         convert_tanh_sigmoid_to_lut,
     ]
@@ -1269,24 +1262,9 @@
             [fuse_activation_function_with_prev, optimise_pad, add_padding_fields],
         )
 
-    # Post-optimisation operator debug tracing
+    # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
     for sg in nng.subgraphs:
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [_record_optimised])
-
-    if verbose_graph:
-        nng.print_graph()
-    return nng
-
-
-def optimise_graph_b(nng, arch, verbose_graph=False):
-    if verbose_graph:
-        nng.print_graph()
-
-    for idx, sg in enumerate(nng.subgraphs):
-        # combined rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [],
-        )
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised])
 
     if verbose_graph:
         nng.print_graph()
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 60e62aa..e514e76 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -374,13 +374,13 @@
         if cmd.is_npu_pass_command():
             if cmd.is_first:
                 ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
-                    cmd.ifm_box.start_coord, cmd.ps.ifm_shapes[0].as_list(), is_top_box=False
+                    cmd.ifm_box.start_coord, cmd.ps.ifm_shapes[0], is_top_box=False
                 )
                 if ifm_read is None:
                     return 0
             if cmd.is_last:
                 write_offset = cmd.ofm_tensor.address_offset_for_coordinate(
-                    cmd.ofm_box.end_coord, cmd.ps.ofm_shapes[0].as_list(), is_top_box=True
+                    cmd.ofm_box.end_coord, cmd.ps.ofm_shapes[0], is_top_box=True
                 )
                 if write_offset is None:
                     return 0
@@ -393,7 +393,7 @@
 
             if cmd.is_first:
                 ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
-                    cmd.ifm_box.end_coord, cmd.ps.ifm_shapes[0].as_list(), is_top_box=True
+                    cmd.ifm_box.end_coord, cmd.ps.ifm_shapes[0], is_top_box=True
                 )
 
     min_overlap = max(min_overlap, 0)
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 8e4d33a..3143483 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -43,7 +43,6 @@
 from .api import NpuShape3D
 from .api import NpuTileBox
 from .architecture_features import ArchitectureFeatures
-from .architecture_features import Block
 from .data_type import DataType
 from .debug_database import DebugDatabase
 from .errors import UnsupportedFeatureError
@@ -152,7 +151,7 @@
     # because of activation function needed to be fused.
     if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > 0:
         left = 0
-    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < Block.from_shape(cmd.ifm_tensor.shape).width:
+    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < cmd.ps.ifm_shapes[0].width:
         right = 0
     return NpuPadding(top=top, left=left, bottom=bottom, right=right)
 
@@ -233,7 +232,7 @@
     return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
 
 
-def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: Shape4D) -> NpuFeatureMap:
+def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
     """Creates feature map with common fields populated"""
     fm = NpuFeatureMap()
     fm.region = get_region(tens, arch)
@@ -244,14 +243,16 @@
         fm.layout = NpuLayout.NHCWB16
     else:
         assert 0, "Incorrect tensor format"
-    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord, fm_shape)
+    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
+        box.start_coord, box.end_coord, op_shape4D
+    )
     for idx, addr in enumerate(addresses):
         if addr is None:
             addresses[idx] = 0
     fm.tiles = NpuTileBox(
         height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
     )
-    strides = tens.get_strides()
+    strides = tens.get_strides(shape4D=op_shape4D)
     fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
     return fm
 
@@ -325,7 +326,7 @@
     op = ps.primary_op
 
     ifm_height = cmd.ifm_box.get_block().height
-    ifm_width = Block.from_shape(cmd.ifm_tensor.shape).width
+    ifm_width = cmd.ps.ifm_shapes[0].width
     ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
 
     npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
@@ -401,7 +402,9 @@
     npu_op = NpuElementWiseOperation(elemwise_op)
 
     if elemwise_op not in UNARY_ELEMWISE_OPS:
-        if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
+        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
+        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
+        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
             # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
             cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
             cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
@@ -416,7 +419,7 @@
             npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
         else:
             ifm2_blk = cmd.ifm2_box.get_block()
-            ifm2_width = Block.from_shape(cmd.ifm2_tensor.shape).width
+            ifm2_width = ps.ifm_shapes[1].width
             npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
     set_common_op_fields(npu_op, cmd, arch)
     # Check if output scale needs to be overridden
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index c2418d7..3acd5e6 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -117,15 +117,21 @@
     return min(ifm_depth, ifm_blk_depth)
 
 
-def get_minimal_cmd_cycles(arch, ifm_tensor, ofm_tensor, ifm_blk: Block, ofm_blk: Block, output_cycles, dpu_cycles=0):
+def get_minimal_cmd_cycles(
+    arch, ifm_tensor, ofm_tensor, ifm_blk: Block, ofm_blk: Block, output_cycles, ifm_shape4D, ofm_shape4D, dpu_cycles=0
+):
     ifm_tens_blk = Tensor((1, ifm_blk.height, ifm_blk.width, ifm_blk.depth), ifm_tensor.dtype, "ifm_blk")
     ofm_tens_blk = Tensor((1, ofm_blk.height, ofm_blk.width, ofm_blk.depth), ofm_tensor.dtype, "ofm_blk")
     cycles_ifm_blk = (
-        estimate_memory_transfer_efficiency(arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk)
+        estimate_memory_transfer_efficiency(
+            arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk, shape4D=ifm_shape4D
+        )
         / arch.memory_bandwidths_per_cycle[ifm_tensor.mem_area]
     )
     cycles_ofm_blk = (
-        estimate_memory_transfer_efficiency(arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk)
+        estimate_memory_transfer_efficiency(
+            arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk, shape4D=ofm_shape4D
+        )
         / arch.memory_bandwidths_per_cycle[ofm_tensor.mem_area]
     )
     return (
@@ -204,7 +210,14 @@
     if primary_op.type.is_elementwise_op() and block_config is not None:
         num_elems_blk = block_config.width * block_config.height * block_config.depth
         cycle_cmd = get_minimal_cmd_cycles(
-            arch, ifm_tensor, ofm_tensor, block_config, block_config, num_elems_blk * cycle_per_elem
+            arch,
+            ifm_tensor,
+            ofm_tensor,
+            block_config,
+            block_config,
+            num_elems_blk * cycle_per_elem,
+            primary_op.ifm_shapes[0],
+            primary_op.ofm_shapes[0],
         )
         cycle_per_elem = max(cycle_per_elem, cycle_cmd / num_elems_blk)
 
@@ -343,7 +356,15 @@
         cycles_output_blk = max(cycles_output_blk, cycles_bias_blk)
 
     cycles_cmd = get_minimal_cmd_cycles(
-        arch, ifm_tensor, ofm_tensor, ifm_block, ofm_block, cycles_dpu_blk, cycles_output_blk
+        arch,
+        ifm_tensor,
+        ofm_tensor,
+        ifm_block,
+        ofm_block,
+        cycles_dpu_blk,
+        ifm_tens_shape,
+        ofm_tens_shape,
+        cycles_output_blk,
     )
     cycles_dpu_blk = max(cycles_dpu_blk, cycles_cmd)
     cycles_output_blk = max(cycles_output_blk, cycles_cmd)
@@ -356,7 +377,9 @@
     return total_cycles
 
 
-def estimate_memory_transfer_efficiency(arch, mem_area, direction, tensor, block_size: Block, replace_bw=None):
+def estimate_memory_transfer_efficiency(
+    arch, mem_area, direction, tensor, block_size: Block, replace_bw=None, shape4D=None
+):
     if tensor.format not in (TensorFormat.NHWC, TensorFormat.NHCWB16):
         return tensor.bandwidth() if replace_bw is None else replace_bw
 
@@ -368,9 +391,10 @@
     tens = tensor.clone()
     if not tens.avoid_NHCWB16:
         tens.set_format(TensorFormat.NHCWB16, arch)
+    strides = tens.get_strides(shape4D=shape4D)
 
     if tens.format == TensorFormat.NHCWB16:
-        if tens.get_strides()[1] == block_size.depth:
+        if strides[1] == block_size.depth:
             burst_len = elem_size * block_size.depth * block_size.width
         elif is_ifm:
             burst_len = 16 * elem_size * block_size.width
@@ -379,12 +403,12 @@
     else:
         assert tens.format == TensorFormat.NHWC
         if is_ifm:
-            if tens.get_strides()[3] == block_size.depth:
+            if strides[3] == block_size.depth:
                 burst_len = elem_size * block_size.depth * block_size.width
             else:
                 burst_len = elem_size * block_size.depth
         else:
-            if block_size.depth <= 16 and tens.get_strides()[3] == block_size.depth:
+            if block_size.depth <= 16 and strides[3] == block_size.depth:
                 burst_len = elem_size * block_size.depth * block_size.width
             else:
                 burst_len = min(64, 16 * elem_size * arch.ncores, block_size.depth * elem_size)
@@ -585,12 +609,12 @@
             scaled_bws[arch.fast_storage_mem_area][tens.purpose][
                 BandwidthDirection.Write
             ] += estimate_memory_transfer_efficiency(
-                arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block
+                arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block, shape4D=ps.ofm_shapes[0],
             )
         else:
             bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
             scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_transfer_efficiency(
-                arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block
+                arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block, shape4D=ps.ofm_shapes[0]
             )
 
     for tens in ps.intermediates:
@@ -612,8 +636,16 @@
             bw = tens.bandwidth()
 
         bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
+
+        op_shape = None
+        if ps.placement == PassPlacement.Npu and primary_op:
+            if tens == ps.ifm_tensor:
+                op_shape = ps.ifm_shapes[0]
+            elif tens == ps.ifm2_tensor:
+                op_shape = ps.ifm_shapes[1]
+
         scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_transfer_efficiency(
-            arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, bw
+            arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, bw, op_shape
         )
 
     # quick build access counts for only current pass, even though these aren't the final numbers
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 844f298..342efd9 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -629,7 +629,6 @@
         elif self.type == Op.StridedSlice:
             input_tens, begin_tens, end_tens, strides_tens = self.inputs
             outputs = self.outputs
-            out_tens = outputs[0]
 
             # Extract masks
             begin_mask = self.attrs["begin_mask"]
@@ -641,7 +640,6 @@
             # shrink_axis_mask/new_axis_mask/ellipsis_mask is not supported by the Operation class but the operation
             # may have the attribute modified and handled in the graph optimization phase.
             assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
-            assert len(input_tens.shape) == len(out_tens.shape)
             offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
             offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
         elif self.type == Op.UnpackReshaped:
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 7015b79..e7e4bbb 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -24,7 +24,7 @@
 from .operation import Op
 from .operation import Operation
 from .operation import Padding
-from .tensor import create_reshape_tensor
+from .shape4d import Shape4D
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 
@@ -44,12 +44,17 @@
 
 
 def create_depthwise_maxpool(
-    name: str, ifm: Tensor, quantization: QuantizationParameters, activation: Optional[ActivationFunction] = None
+    name: str,
+    ifm: Tensor,
+    inp_shape: Shape4D,
+    quantization: QuantizationParameters,
+    activation: Optional[ActivationFunction] = None,
 ) -> Operation:
     op = Operation(Op.MaxPool, name)
-    height = ifm.shape[1] * ifm.shape[2]
-    width = ifm.shape[3]
-    ifm_shape = [1, height, width, 1]
+    height = inp_shape.height * inp_shape.width
+    width = inp_shape.depth
+    ifm_shape = Shape4D([1, height, width, 1])
+
     op.attrs["padding"] = Padding.VALID
     op.attrs["stride_w"] = 1
     op.attrs["stride_h"] = 1
@@ -58,11 +63,14 @@
     op.attrs["strides"] = [1, op.attrs["stride_h"], op.attrs["stride_w"], 1]
     op.attrs["ksize"] = [1, op.attrs["filter_height"], op.attrs["filter_width"], 1]
     op.activation = activation
-    op.inputs = [create_reshape_tensor(ifm, ifm_shape)]
+    op.inputs = [ifm]
     ofm = Tensor([1, height, 1, 1], ifm.dtype, op.name + "_tens0")
     ofm.quantization = quantization
     op.set_output_tensor(ofm)
-    op.set_ifm_ofm_shapes()
+    op.ifm_shapes.append(ifm_shape)
+    op.ofm_shapes.append(Shape4D(ofm.shape))
+    op.ifm.avoid_NHCWB16 = True
+    op.ofm.avoid_NHCWB16 = True
     return op
 
 
@@ -95,8 +103,12 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
+    ifm2_shape: Optional[Shape4D] = None,
 ) -> Operation:
-    return create_binary_elementwise(Op.Add, name, ifm, ifm2, quantization, activation, dtype, attrs)
+    return create_binary_elementwise(
+        Op.Add, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape
+    )
 
 
 def create_rescale_add(
@@ -108,8 +120,12 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
+    ifm2_shape: Optional[Shape4D] = None,
 ) -> Operation:
-    op = create_binary_elementwise(Op.RescaleAdd, name, ifm, ifm2, quantization, activation, dtype, attrs)
+    op = create_binary_elementwise(
+        Op.RescaleAdd, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape
+    )
     op.rescale = rescale
     return op
 
@@ -121,8 +137,9 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
 ) -> Operation:
-    return create_unary_elementwise(Op.CLZ, name, ifm, quantization, activation, dtype, attrs)
+    return create_unary_elementwise(Op.CLZ, name, ifm, quantization, activation, dtype, attrs, ifm_shape)
 
 
 def create_mul(
@@ -133,8 +150,12 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
+    ifm2_shape: Optional[Shape4D] = None,
 ) -> Operation:
-    return create_binary_elementwise(Op.Mul, name, ifm, ifm2, quantization, activation, dtype, attrs)
+    return create_binary_elementwise(
+        Op.Mul, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape
+    )
 
 
 def create_shl(
@@ -145,8 +166,12 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
+    ifm2_shape: Optional[Shape4D] = None,
 ) -> Operation:
-    return create_binary_elementwise(Op.SHL, name, ifm, ifm2, quantization, activation, dtype, attrs)
+    return create_binary_elementwise(
+        Op.SHL, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape
+    )
 
 
 def create_shr(
@@ -157,8 +182,12 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
+    ifm2_shape: Optional[Shape4D] = None,
 ) -> Operation:
-    return create_binary_elementwise(Op.SHR, name, ifm, ifm2, quantization, activation, dtype, attrs)
+    return create_binary_elementwise(
+        Op.SHR, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape
+    )
 
 
 def create_sub(
@@ -169,8 +198,12 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
+    ifm2_shape: Optional[Shape4D] = None,
 ) -> Operation:
-    return create_binary_elementwise(Op.Sub, name, ifm, ifm2, quantization, activation, dtype, attrs)
+    return create_binary_elementwise(
+        Op.Sub, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape
+    )
 
 
 def create_unary_elementwise(
@@ -181,8 +214,9 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
 ) -> Operation:
-    return create_binary_elementwise(op_type, name, ifm, None, quantization, activation, dtype, attrs)
+    return create_binary_elementwise(op_type, name, ifm, None, quantization, activation, dtype, attrs, ifm_shape, None)
 
 
 def create_binary_elementwise(
@@ -194,19 +228,34 @@
     activation: Optional[ActivationFunction] = None,
     dtype: Optional[DataType] = None,
     attrs: Optional[dict] = None,
+    ifm_shape: Optional[Shape4D] = None,
+    ifm2_shape: Optional[Shape4D] = None,
 ) -> Operation:
+    if ifm_shape is None:
+        ifm_shape = Shape4D(ifm.shape)
     op = Operation(op_type, name)
     op.add_input_tensor(ifm)
+    op.ifm_shapes.append(ifm_shape)
     if ifm2:
         op.add_input_tensor(ifm2)
+        if ifm2_shape is None:
+            ifm2_shape = Shape4D(ifm2.shape)
+        op.ifm_shapes.append(ifm2_shape)
     op.activation = activation
     if not dtype:
         dtype = ifm.dtype
     if attrs:
         op.attrs.update(attrs)
-    ofm_shape = ifm.shape if ifm2 is None or ifm_ifm2_correct_order(ifm.shape, ifm2.shape) else ifm2.shape
-    ofm = Tensor(ofm_shape, dtype, f"{op.name}_tens0")
+
+    if ifm2 is None:
+        ofm_shape = ifm_shape
+    else:
+        in_shape = [] if ifm.shape == [] else ifm_shape.as_list()
+        in2_shape = [] if ifm2.shape == [] else ifm2_shape.as_list()
+        ofm_shape = ifm_shape if ifm_ifm2_correct_order(in_shape, in2_shape) else ifm2_shape
+
+    ofm = Tensor(ofm_shape.as_list(), dtype, f"{op.name}_tens0")
     ofm.quantization = quantization
     op.set_output_tensor(ofm)
-    op.set_ifm_ofm_shapes()
+    op.ofm_shapes.append(ofm_shape)
     return op
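
The factory helpers above all forward two new optional Shape4D arguments to create_binary_elementwise, which falls back to deriving a 4D shape from the tensor when they are omitted. A minimal sketch of the intended call pattern (ifm, ifm2 and quant are placeholders for real tensors and QuantizationParameters):

    # Hypothetical usage sketch: explicit shapes override the tensor shapes,
    # which lets callers such as softmax pass broadcast shapes directly.
    ifm_shape = Shape4D([1, 8, 8, 16])
    ifm2_shape = Shape4D([1, 1, 1, 16])  # broadcast second input
    add_op = create_add("example_add", ifm, ifm2, quant, ifm_shape=ifm_shape, ifm2_shape=ifm2_shape)
    assert add_op.ifm_shapes == [ifm_shape, ifm2_shape]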
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index ee0d712..a95e383 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -150,7 +150,7 @@
         # ops_set
         npu_pre_ops,
         # incompatible_pack_flags
-        PassFlags.Cpu | PassFlags.MemoryOnly,
+        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.ElementWise,
         # flags_to_set
         PassFlags.Npu | PassFlags.Mac | PassFlags.Pre | PassFlags.ElementWise,
         # flags_to_clear
@@ -458,11 +458,11 @@
             avgpool_out = inp.clone("_avgpooled")
             avgpool_out.consumer_list.append(op)
             avgpool_op.set_output_tensor(avgpool_out)
-            avgpool_op.set_ifm_ofm_shapes()
+            avgpool_op.ifm_shapes = op.ifm_shapes.copy()
+            avgpool_op.ofm_shapes = op.ofm_shapes.copy()
 
             op.inputs[0] = avgpool_out
             op_list.insert(0, avgpool_op)
-            op.set_ifm_ofm_shapes()
 
             DebugDatabase.add_optimised(op, avgpool_op)
             return avgpool_op
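
The inserted average pool now copies the shape lists of the op it precedes instead of recomputing them via set_ifm_ofm_shapes. The .copy() is deliberate, since Python lists are shared by reference, as this standalone sketch illustrates:

    # A shallow copy gives each op its own list, so a later in-place change
    # to one op's shape list cannot silently alter the other's.
    shapes_a = [Shape4D([1, 8, 8, 16])]
    shapes_b = shapes_a.copy()
    shapes_b.append(Shape4D([1, 1, 1, 16]))
    assert len(shapes_a) == 1 and len(shapes_b) == 2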
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index a1b4fea..8981e20 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -75,3 +75,6 @@
 
     def as_list(self):
         return list(self._shape4D)
+
+    def get_hw_as_list(self):
+        return [self.height, self.width]
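
A quick sketch of the new accessor:

    # get_hw_as_list() returns only the spatial dimensions, handy when
    # building kernel sizes or broadcast shapes such as [1, H, W, 1].
    s = Shape4D([1, 4, 16, 16])
    assert s.get_hw_as_list() == [4, 16]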
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 8a1770e..656a7e6 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -39,8 +39,8 @@
 from .operation_util import create_shl
 from .operation_util import create_shr
 from .operation_util import create_sub
+from .shape4d import Shape4D
 from .tensor import create_const_tensor
-from .tensor import create_reshape_tensor
 from .tensor import TensorPurpose
 
 
@@ -214,12 +214,13 @@
         ofm = self.op.outputs[0]
 
         # Reshape ifm/ofm (if needed)
-        full_shape = self.op.ifm_shapes[0].as_list()
-        if full_shape[0] > 1:
-            full_shape[1] *= full_shape[0]
-            full_shape[0] = 1
-        ifm = create_reshape_tensor(ifm, full_shape)
-        ofm = create_reshape_tensor(ofm, full_shape, False)
+        ifm_shape = self.op.ifm_shapes[0]
+        if ifm_shape.batch > 1:
+            ifm_shape.height = ifm_shape.batch * ifm_shape.height
+            ifm_shape.batch = 1
+            self.op.ifm.avoid_NHCWB16 = True
+            self.op.ofm_shapes[0] = ifm_shape.clone()
+            self.op.ofm.avoid_NHCWB16 = True
 
         if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
             return self.get_graph_8bit(ifm, ofm)
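
Instead of materialising Reshape tensors, batched inputs are now folded directly into the op's 4D shapes: a [N, H, W, C] feature map is reinterpreted as [1, N*H, W, C]. The tensors are flagged avoid_NHCWB16 because that reinterpretation is only valid for a plain NHWC layout. A sketch of the fold:

    # The Shape4D fields are mutated in place, mirroring the code above.
    s = Shape4D([4, 2, 16, 16])
    if s.batch > 1:
        s.height = s.batch * s.height
        s.batch = 1
    assert s.as_list() == [1, 8, 16, 16]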
@@ -233,7 +234,6 @@
         exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
         no_scale_quant = ifm.quantization.clone()
         no_scale_quant.scale_f32 = None
-        no_scale_quant.zero_point = 0
         activation = ActivationFunction(Op.Clip)
         activation.min = ifm.quantization.quant_min
         activation.max = ifm.quantization.quant_max
@@ -245,7 +245,6 @@
         one_scale_quant.zero_point = 0
         two_scale_quant = one_scale_quant.clone()
         two_scale_quant.scale_f32 = 2.0
-        ifm.quantization.zero_point = 0
         pass_number = 0
 
         def add_op_get_ofm(op):
@@ -255,13 +254,25 @@
             return op.ofm
 
         # PASS 0 - Depthwise Maxpool
-        ifm_max = add_op_get_ofm(create_depthwise_maxpool(f"{self.op.name}_maxpool{pass_number}", ifm, no_scale_quant))
+        ifm_shape = self.op.ifm_shapes[0]
+        ifm_max = add_op_get_ofm(
+            create_depthwise_maxpool(f"{self.op.name}_maxpool{pass_number}", ifm, ifm_shape, no_scale_quant)
+        )
 
         # PASS 1 - Sub+LUT(exp)
         sub_op_quantization = one_scale_quant.clone()
         sub_op_quantization.zero_point = 127
-        ifm_max = create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1])
-        sub_op = create_sub(f"{self.op.name}_sub{pass_number}", ifm, ifm_max, sub_op_quantization, dtype=DataType.int32)
+        ifm_max_shape = Shape4D([1, ifm_shape.height, ifm_shape.width, 1])
+        ifm_max.avoid_NHCWB16 = True
+        sub_op = create_sub(
+            f"{self.op.name}_sub{pass_number}",
+            ifm,
+            ifm_max,
+            sub_op_quantization,
+            dtype=DataType.int32,
+            ifm_shape=ifm_shape,
+            ifm2_shape=ifm_max_shape,
+        )
         sub_op.set_activation_lut(
             create_const_tensor(
                 f"{sub_op.name}_exp_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
@@ -415,7 +426,9 @@
         shr30_op.add_input_tensor(scaled_exp)
         shr30_op.add_input_tensor(right_shift)
         shr30_op.set_output_tensor(ofm)
-        shr30_op.set_ifm_ofm_shapes()
+        shr30_op.ifm_shapes.append(Shape4D(scaled_exp.shape))
+        shr30_op.ifm_shapes.append(Shape4D(right_shift.shape))
+        shr30_op.ofm_shapes.append(Shape4D(scaled_exp.shape))
         DebugDatabase.add_optimised(self.op, shr30_op)
 
         return shr30_op
@@ -432,12 +445,24 @@
             return op.ofm
 
         # PASS 0 - Depthwise Maxpool
-        ifm_max = add_op_get_ofm(create_depthwise_maxpool(f"{self.op.name}_maxpool{pass_number}", ifm, no_scale_quant))
+        ifm_shape = self.op.ifm_shapes[0]
+        ifm_max = add_op_get_ofm(
+            create_depthwise_maxpool(f"{self.op.name}_maxpool{pass_number}", ifm, ifm_shape, no_scale_quant)
+        )
 
         # PASS 1 - Sub
-        ifm_max = create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1])
+        ifm_max_shape = Shape4D([1, ifm_shape.height, ifm_shape.width, 1])
+        ifm_max.avoid_NHCWB16 = True
         sub1_ofm = add_op_get_ofm(
-            create_sub(f"{self.op.name}_sub{pass_number}", ifm, ifm_max, ifm.quantization.clone(), dtype=DataType.int32)
+            create_sub(
+                f"{self.op.name}_sub{pass_number}",
+                ifm,
+                ifm_max,
+                ifm.quantization.clone(),
+                dtype=DataType.int32,
+                ifm_shape=ifm_shape,
+                ifm2_shape=ifm_max_shape,
+            )
         )
 
         # PASS 2 - Mul
@@ -537,7 +562,9 @@
         shr13_op.add_input_tensor(mul_ofm)
         shr13_op.add_input_tensor(reciprocal_right_shift)
         shr13_op.set_output_tensor(ofm)
-        shr13_op.set_ifm_ofm_shapes()
+        shr13_op.ifm_shapes.append(Shape4D(mul_ofm.shape))
+        shr13_op.ifm_shapes.append(Shape4D(reciprocal_right_shift.shape))
+        shr13_op.ofm_shapes.append(Shape4D(mul_ofm.shape))
         DebugDatabase.add_optimised(self.op, shr13_op)
 
         return shr13_op
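
In both the 8-bit and 16-bit graphs the maxpool result is now consumed through an explicit broadcast shape rather than a Reshape tensor. The per-position maxima carry one value per (H, W) location, so the sub ops broadcast ifm2 over the channel axis:

    # Shape relationship assumed by the sub ops above.
    ifm_shape = Shape4D([1, 8, 8, 16])
    ifm_max_shape = Shape4D([1, ifm_shape.height, ifm_shape.width, 1])
    assert ifm_max_shape.as_list() == [1, 8, 8, 1]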
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index fb877ca..ef8a28f 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -314,26 +314,6 @@
     return const_tensor
 
 
-def create_reshape_tensor(tens, shape, ifm_reshape=True):
-    if shape == tens.shape:
-        return tens
-    # Tensors
-    name = tens.name + "_reshape"
-    reshape_ifm = tens
-    reshape_ofm = tens.clone("_reshaped")
-    reshape_ofm.set_all_shapes(shape)
-    if not ifm_reshape:
-        reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
-    # Operator
-    reshape_op = Operation(Op.Reshape, name)
-    reshape_op.attrs["new_shape"] = shape
-    reshape_op.add_input_tensor(reshape_ifm)
-    reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
-    reshape_op.set_output_tensor(reshape_ofm)
-    reshape_op.set_ifm_ofm_shapes()
-    return reshape_ofm if ifm_reshape else reshape_ifm
-
-
 # class that keeps track of all tensor addresses in the different memory types
 class TensorAddressMap:
     address_map: Dict = defaultdict(dict)  # dict (tens.equivalence_id -> dict (mem_type -> address))
@@ -443,6 +423,10 @@
     def address(self, address: int):
         TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)
 
+    @property
+    def is_standard_fm(self) -> bool:
+        return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap
+
     def element_size(self) -> int:
         if self.element_size_bytes == 0:
             return self.dtype.size_in_bits() / 8
@@ -540,6 +524,15 @@
         rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
         return rounded_size
 
+    def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
+        elems = shape_num_elements(op_storage_shape)
+        elems = elems if elems else 0
+        raw_size = elems * self.element_size()
+        if raw_size == 0:
+            raw_size = 1  # force it to take up space
+        rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
+        return rounded_size
+
     def storage_size_for_sub_purpose(
         self, arch, sub_purpose: TensorSubPurpose, param_a: Optional[int] = None, param_b: Optional[int] = None
     ) -> int:
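
storage_size_for_shape mirrors the existing storage_size but sizes a caller-supplied shape instead of the tensor's own storage shape. A sketch of the arithmetic, assuming a 1-byte element size and 16-byte alignment:

    # shape_num_elements([1, 8, 8, 16]) -> 1024 elements
    raw_size = 1024 * 1                         # elements * element_size()
    rounded = ((raw_size + 16 - 1) // 16) * 16  # round up to alignment
    assert rounded == 1024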
@@ -614,7 +607,11 @@
     def consumers(self) -> List[Operation]:
         return self.consumer_list
 
-    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple:
+    def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
+        rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
+        return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))
+
+    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, op_shape4D: Shape4D) -> Tuple:
         # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
 
         if self.storage_shape == []:
@@ -622,12 +619,16 @@
                 1,
                 1,
                 1,
-                [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None],
+                [self.address_for_coordinate(start_coord, op_shape4D=op_shape4D), None, None, None],
             )
 
-        storage_shape_4D = full_shape(4, self.storage_shape, 1)
-        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D[1])
-        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D[2])
+        if self.is_standard_fm:
+            storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
+        else:
+            storage_shape_4D = Shape4D(self.storage_shape)
+
+        crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
+        crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)
 
         crossing_y = min(crossing_y, end_coord[1])
         crossing_x = min(crossing_x, end_coord[2])
@@ -636,39 +637,41 @@
         box_width = crossing_x - start_coord[2]
 
         addresses: List = [None] * 4
-        addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list())
+        addresses[0] = self.address_for_coordinate(start_coord, op_shape4D=op_shape4D)
 
         if end_coord[2] > crossing_x:
             addresses[1] = self.address_for_coordinate(
-                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list()
+                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], op_shape4D=op_shape4D
             )
             raise UnsupportedFeatureError("Striping in vertical direction is not supported")
         if end_coord[1] > crossing_y:
             addresses[2] = self.address_for_coordinate(
-                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list()
+                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], op_shape4D=op_shape4D
             )
         if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
             addresses[3] = self.address_for_coordinate(
-                [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list()
+                [start_coord[0], crossing_y, crossing_x, start_coord[3]], op_shape4D=op_shape4D
             )
 
         return box_height0, box_height0, box_width, addresses
 
-    def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, shape: Shape = None) -> int:
-        if shape is None:
-            shape = self.shape
-        offset = self.address_offset_for_coordinate(coord, shape, is_top_box)
+    def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, op_shape4D: Shape4D = None) -> int:
+        offset = self.address_offset_for_coordinate(coord, op_shape4D=op_shape4D, is_top_box=is_top_box)
         assert offset is not None
         return self.address + offset
 
-    def get_strides_and_coord(self, coord: Optional[Shape] = None) -> Tuple[Optional[Shape], Optional[Shape]]:
+    def get_strides_and_coord(
+        self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
+    ) -> Tuple[Optional[Shape], Optional[Shape]]:
         if coord is None:
             coord = [0] * len(self.storage_shape)
 
+        if shape4D and self.is_standard_fm:
+            augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
+        else:
+            augmented_shape = full_shape(4, self.storage_shape, 1)
+
         augmented_coord = coord
-        augmented_shape = self.storage_shape
-        while len(augmented_shape) < 4:
-            augmented_shape = [1] + augmented_shape
 
         while len(augmented_coord) < 4:
             augmented_coord = [0] + augmented_coord
@@ -713,8 +716,8 @@
 
         return strides, augmented_coord
 
-    def get_strides(self) -> Shape:
-        strides, _ = self.get_strides_and_coord()
+    def get_strides(self, shape4D: Optional[Shape4D] = None) -> Shape:
+        strides, _ = self.get_strides_and_coord(shape4D=shape4D)
         assert strides is not None
         return strides
 
@@ -769,13 +772,13 @@
         assert 0 <= index < len(self.compressed_values)
         return index == len(self.compressed_values) - 1
 
-    def address_offset_for_coordinate(self, orig_coord: Shape, shape: Shape, is_top_box: bool = False) -> Optional[int]:
+    def address_offset_for_coordinate(
+        self, orig_coord: Shape, op_shape4D: Optional[Shape4D] = None, is_top_box: bool = False
+    ) -> Optional[int]:
         address_offset = 0
-        coord = orig_coord
-
-        coord = coord[-len(self.storage_shape) :]
 
         if self.sub_purpose == TensorSubPurpose.Standard:
+            shape = op_shape4D.as_list() if op_shape4D else self.shape
             for idx, c in enumerate(orig_coord):
                 if is_top_box:
                     assert c > 0 and c <= shape[idx]
@@ -783,6 +786,7 @@
                     assert c >= 0 and c < shape[idx]
 
         if self.format == TensorFormat.WeightsCompressed:
+            storage_size = self.storage_size()
             if len(self.weight_compressed_offsets) == 0:
                 return 0
 
@@ -814,13 +818,22 @@
                 assert index < len(self.weight_compressed_offsets)
                 address_offset = self.weight_compressed_offsets[index]
         else:
+            coord = orig_coord
+            if op_shape4D and self.is_standard_fm:
+                storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
+                storage_size = self.storage_size_for_shape(storage_shape)
+            else:
+                storage_shape = self.storage_shape
+                coord = coord[-len(storage_shape) :]
+                storage_size = self.storage_size()
+
             if is_top_box:
                 coord = [c - 1 for c in coord]
 
             # handle wraparound for partial buffers. make sure to do this after subtracting top box:
-            coord = [c % self.storage_shape[idx] for idx, c in enumerate(coord)]
+            coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]
 
-            strides, augmented_coord = self.get_strides_and_coord(coord)
+            strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D)
             if strides is None:
                 return None
 
@@ -830,7 +843,7 @@
             address_offset += np.dot(augmented_coord, strides)
 
         assert address_offset >= 0
-        assert address_offset <= self.storage_size()
+        assert address_offset <= storage_size
         return address_offset
 
     def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
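
Taken together, the tensor.py changes make addressing relative to an op-local Shape4D view instead of the tensor's own shape, which is what allows a Reshape to be removed while its producer and consumer see different shapes. A hypothetical call, assuming tens is a standard feature map:

    # The same allocated storage is addressed through the view the consuming
    # op uses; strides are derived from that view for standard feature maps.
    view = Shape4D([1, 8, 8, 16])
    addr = tens.address_for_coordinate([0, 1, 0, 0], op_shape4D=view)
    strides = tens.get_strides(shape4D=view)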
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index b3938bc..b01b07c 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -20,10 +20,12 @@
 
 from ethosu.vela.data_type import DataType
 from ethosu.vela.graph_optimiser import convert_batched_fc_shape
+from ethosu.vela.graph_optimiser import optimise_graph_a
 from ethosu.vela.graph_optimiser import optimise_pad
 from ethosu.vela.nn_graph import Graph
 from ethosu.vela.operation import Op
 from ethosu.vela.operation import Padding
+from ethosu.vela.rewrite_graph import verify_graph_health
 from ethosu.vela.tensor import create_const_tensor
 from ethosu.vela.tensor import Shape4D
 from ethosu.vela.tensor import Tensor
@@ -32,50 +34,49 @@
 
 def test_convert_batched_fc():
     """Tests shape conversion of batched fully connected"""
-    shape = [4, 8]
-    ifm = create_const_tensor("test_in", shape, np.uint8, np.zeros(shape))
-    weights = create_const_tensor("weight_in", shape, np.uint8, np.zeros(shape))
+    ifm_shape = [4, 8]
+    ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
+    w_shape = [8, 4]
+    weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
     ofm = Tensor(ifm.shape, np.uint8, "test_out")
     op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
 
     ifm.consumer_list.append(op)
 
-    op.ifm_shapes.append(Shape4D([4, 1, 1, 8]))
-    op.ofm_shapes.append(Shape4D([4, 1, 1, 8]))
-
     prev_op = op.clone()
-    prev_op.ifm_shapes = op.ifm_shapes
-    prev_op.ofm_shapes = op.ofm_shapes
-
-    conv_op = convert_batched_fc_shape(op, None, None)
-
-    assert conv_op.ifm != prev_op.ifm
-    assert conv_op.ofm != prev_op.ofm
-    assert conv_op.type == Op.FullyConnected
-    assert len(conv_op.ifm.shape) == 4
-    assert conv_op.ifm.shape == conv_op.ofm.shape
-    assert conv_op.ifm.ops[0].type == Op.Reshape
-
-    shape = [1, 8]
-    ifm.shape = shape
-    weights.shape = shape
-    ofm.shape = shape
-    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
-    ifm.consumer_list.append(op)
-
-    op.ifm_shapes.append([1, 1, 1, 8])
-    op.ofm_shapes.append([1, 1, 1, 8])
-
-    prev_op = op.clone()
-    prev_op.ifm_shapes = op.ifm_shapes
-    prev_op.ofm_shapes = op.ofm_shapes
+    prev_op.ifm_shapes = op.ifm_shapes.copy()
+    prev_op.ofm_shapes = op.ofm_shapes.copy()
 
     conv_op = convert_batched_fc_shape(op, None, None)
 
     assert conv_op.ifm == prev_op.ifm
     assert conv_op.ofm == prev_op.ofm
+    assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
+    assert op.ofm_shapes[0] == Shape4D([1, 2, 2, 8])
     assert conv_op.type == Op.FullyConnected
     assert len(conv_op.ifm.shape) == 2
+    assert len(conv_op.ofm.shape) == 2
+    assert conv_op.ifm.shape == conv_op.ofm.shape
+
+    ifm.shape = [1, 8]
+    weights.shape = [8, 1]
+    ofm.shape = [1, 8]
+    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
+    ifm.consumer_list.append(op)
+
+    prev_op = op.clone()
+    prev_op.ifm_shapes = op.ifm_shapes.copy()
+    prev_op.ofm_shapes = op.ofm_shapes.copy()
+
+    conv_op = convert_batched_fc_shape(op, None, None)
+
+    assert conv_op.ifm == prev_op.ifm
+    assert conv_op.ofm == prev_op.ofm
+    assert op.ifm_shapes[0] == prev_op.ifm_shapes[0]
+    assert op.ofm_shapes[0] == prev_op.ofm_shapes[0]
+    assert conv_op.type == Op.FullyConnected
+    assert len(conv_op.ifm.shape) == 2
+    assert len(conv_op.ofm.shape) == 2
     assert conv_op.ifm.shape == conv_op.ofm.shape
 
 
@@ -118,3 +119,91 @@
     assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
     assert op.ifm.shape == [1, 76, 75, 64]
     assert pad_op not in op.ifm.ops
+
+
+def test_remove_reshape():
+    """
+    Tests that the expected Reshape ops are removed during graph optimisation
+    """
+
+    def setup_network():
+        quant = testutil.default_quant_params()
+        # create reshape1 op
+        ifm_shape = [64, 16]
+        reshape1_ofm_shape = [1, 4, 16, 16]
+        reshape1_ifm = create_const_tensor("reshape1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
+        reshape1_ifm.quantization = quant
+        reshape1_ofm = create_const_tensor(
+            "reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape)
+        )
+        reshape1_ofm.quantization = quant
+        shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
+        reshape1_op = testutil.create_op(Op.Reshape, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
+        reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
+        reshape1_op.run_on_npu = True
+
+        # create reshape2 op
+        reshape2_ofm_shape = [1, 8, 8, 16]
+        reshape2_ofm = create_const_tensor(
+            "reshape2_out", reshape2_ofm_shape, DataType.uint8, np.zeros(reshape2_ofm_shape)
+        )
+        reshape2_ofm.quantization = quant
+        shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, reshape2_ofm_shape)
+        reshape2_op = testutil.create_op(Op.Reshape, [reshape1_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
+        reshape2_op.attrs["new_shape"] = reshape2_ofm_shape
+        reshape2_op.run_on_npu = True
+
+        # create conv op
+        conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
+        conv_ofm.quantization = quant.clone()
+        weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
+        weight_tens.values = np.zeros(weight_tens.shape)
+        weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
+        weight_tens.quantization = quant.clone()
+        bias_tens = Tensor([16], DataType.int32, "biases")
+
+        attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
+        attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+
+        conv2d_op = testutil.create_op(
+            Op.Conv2D, [reshape1_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
+        )
+        conv2d_op.run_on_npu = True
+
+        # create reshape3 op
+        ofm_shape = [8, 8, 16]
+        reshape3_ofm = create_const_tensor("reshape3_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+        reshape3_ofm.quantization = quant
+        shape_tens = create_const_tensor("reshape3_shape", [1], DataType.int32, ofm_shape)
+        reshape3_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape3_ofm, set_ifm_ofm_shapes=False)
+        reshape3_op.attrs["new_shape"] = ofm_shape
+        reshape3_op.run_on_npu = True
+        nng = Graph()
+        sg = testutil.create_subgraph([reshape1_op, reshape2_op, conv2d_op, reshape3_op])
+        nng.subgraphs.append(sg)
+
+        return nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op
+
+    # Test 1: no Reshape op is expected to remain in the NPU subgraph,
+    # but the first one will be moved to the CPU.
+    # Network is Reshape-Reshape-Conv-Reshape
+    # Result is cpu_Reshape-Conv
+    nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op = setup_network()
+    arch = testutil.create_arch()
+    assert verify_graph_health(nng)
+    nng = optimise_graph_a(nng, arch)
+    assert verify_graph_health(nng)
+    assert conv2d_op.ifm == reshape1_op.ofm
+    assert conv2d_op.ofm == reshape3_op.ofm
+
+    # Test 2: reshape2 has different quantisation, so this Reshape op is expected to remain.
+    # Network is Reshape-Reshape-Conv-Reshape
+    # Expected result is cpu_Reshape-Reshape-Conv
+    nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op = setup_network()
+    quant_zp32 = testutil.default_quant_params()
+    quant_zp32.zero_point = 32
+    reshape2_op.ofm.quantization = quant_zp32
+    assert verify_graph_health(nng)
+    nng = optimise_graph_a(nng, arch)
+    assert verify_graph_health(nng)
+    assert conv2d_op.ofm == reshape3_op.ofm
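
The second scenario exercises the rule from the commit message that a Reshape with differing ifm/ofm quantisation must remain. A sketch of the condition being set up:

    # reshape2's output quantisation is given a different zero point, so
    # removing the op would change the numerical result.
    q_in = testutil.default_quant_params()
    q_out = testutil.default_quant_params()
    q_out.zero_point = 32
    assert q_in.zero_point != q_out.zero_point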
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 96aeb7e..02e01a5 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -113,14 +113,15 @@
     return op
 
 
-def create_op(op_type, inputs, output, attrs=None):
+def create_op(op_type, inputs, output, attrs=None, set_ifm_ofm_shapes=True):
     op = Operation(op_type, output.name + "_op")
     for input in inputs:
         op.add_input_tensor(input)
     op.set_output_tensor(output)
     if attrs is not None:
         op.attrs = attrs
-    op.set_ifm_ofm_shapes()
+    if set_ifm_ofm_shapes:
+        op.set_ifm_ofm_shapes()
     return op
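
The new flag lets tests construct ops whose shape lists start out empty, so graph optimisation itself is what populates them. A sketch with hypothetical tensors:

    # With set_ifm_ofm_shapes=False the helper leaves the lists empty;
    # ifm, shape_tens and ofm are placeholders for real tensors.
    op = create_op(Op.Reshape, [ifm, shape_tens], ofm, set_ifm_ofm_shapes=False)
    assert op.ifm_shapes == [] and op.ofm_shapes == []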