MLBEDSW-3772 Reshape removal

- Removed Reshape ops from the original graph
- Stopped inserting Reshape ops into the
  optimized graph

- Reshapes whose ifm and ofm quantisation differ are kept

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I94862be53dac0d7434815e2aee5ca678228495f8
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 5f11178..bb5a9e0 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -28,6 +28,7 @@
 from .data_type import DataType
 from .debug_database import DebugDatabase
 from .errors import UnsupportedFeatureError
+from .errors import VelaError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .numeric_util import clamp_sigmoid
 from .numeric_util import full_shape
@@ -42,7 +43,6 @@
 from .softmax import SoftMax
 from .tensor import check_quantized_tens_scaling_equal
 from .tensor import create_const_tensor
-from .tensor import create_reshape_tensor
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 from .tflite_mapping import optype_to_builtintype
@@ -59,52 +59,68 @@
     return tens
 
 
-def rewrite_concat(tens, arch, nng):
-    if len(tens.ops) == 1 and tens.ops[0].type.is_concat_op():
-        concat_op = tens.ops[0]
-        if tens != concat_op.outputs[0]:
-            return tens  # don't attempt to rewrite the min/max outputs of QuantizedConcat
+def rewrite_concat_ops(op, arch, nng):
+    if not op.run_on_npu or not op.type.is_concat_op():
+        return op
 
-        # Not supported so leave it and run on CPU
-        if not concat_op.run_on_npu:
-            return tens
+    axis_4D = 0
+    ofm = op.ofm
+    ofm.ops = []
+    offset = 0
 
-        inputs, axis = concat_op.get_concat_inputs_axis()
+    if op.type == Op.Pack:
+        # Pack is also referred to as Stack
+        axis = int(op.attrs["axis"])
+        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
 
-        tens.ops = []
-        offset = 0
-        for idx, inp in enumerate(inputs):
+        if axis >= 0:
+            axis_4D = axis + (4 - len(desired_shape))
+        else:
+            axis_4D = axis
+
+        for idx, inp in enumerate(op.inputs):
+            op.ifm_shapes[idx] = Shape4D(desired_shape)
+            if Shape4D(inp.shape) != op.ifm_shapes[idx]:
+                inp.avoid_NHCWB16 = True
+        op.type = Op.PackReshaped
+
+    inputs, axis = op.get_concat_inputs_axis()
+
+    for idx, inp in enumerate(inputs):
+        if op.type != Op.PackReshaped:
+            op.ifm_shapes[idx] = Shape4D(inp.shape)
             if axis >= 0:
                 axis_4D = axis + (4 - len(inp.shape))
             else:
                 axis_4D = axis
-            new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
-            new_op.inputs = [inp]
-            new_op.outputs = [tens]
-            new_op.attrs["concat_axis"] = axis_4D
-            new_op.attrs["concat_start"] = offset
-            offset += inp.shape[axis]
-            new_op.attrs["concat_end"] = offset
-            new_op.run_on_npu = True
-            tens.ops.append(new_op)
-            DebugDatabase.add_optimised(concat_op, new_op)
-            new_op.set_ifm_ofm_shapes()
-        assert tens.shape[axis] == offset
+        new_op = Operation(Op.ConcatSliceWrite, op.name + str(idx))
+        new_op.inputs = [inp]
+        new_op.outputs = [ofm]
+        new_op.attrs["concat_axis"] = axis_4D
+        new_op.attrs["concat_start"] = offset
+        offset += op.ifm_shapes[idx].get_dim(axis_4D)
 
-        # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
-        # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
-        # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
-        # and those addresses are always 16 byte aligned due to the NHCWB16 format.
-        if axis == -1 or axis == (len(tens.shape) - 1):
-            for op in tens.ops:
-                if op.attrs["concat_start"] % 16 != 0:
-                    tens.avoid_NHCWB16 = True
-                    break
+        new_op.attrs["concat_end"] = offset
+        new_op.run_on_npu = True
+        ofm.ops.append(new_op)
+        DebugDatabase.add_optimised(op, new_op)
+        new_op.ifm_shapes.append(op.ifm_shapes[idx].clone())
+        new_op.ofm_shapes.append(op.ofm_shapes[0].clone())
+    assert ofm.shape[axis] == offset
 
-    return tens
+    # If axis corresponds to the C-dimension, NHCWB16 can only be used in the output if every concat_start is a
+    # multiple of 16, because only then is the ofm address offset 16-byte aligned for all of the write ops. For other
+    # axis values the address offsets are always 16-byte aligned, as they are all based on c = 0 and those addresses
+    # are 16-byte aligned due to the NHCWB16 format.
+    if axis == -1 or axis == (len(ofm.shape) - 1):
+        for slice_op in ofm.ops:  # must not reuse the name "op": the concat op is returned below
+            if slice_op.attrs["concat_start"] % 16 != 0:
+                ofm.avoid_NHCWB16 = True
+                break
+    return op
 
 
-def rewrite_split(tens, arch, nng):
+def rewrite_split_ops(tens, arch, nng):
 
     if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
         split_op = tens.ops[0]
@@ -118,20 +134,27 @@
         tens.ops = []
         new_op = Operation(Op.SplitSliceRead, split_op.name)
         new_op.inputs = [inp]
+        ofm_shape_idx = 0
 
         # For Split the offset cannot be extracted from the tensor so it has to
         # be calculated from the index of the output tensor
         if axis is not None:
             # Get the start and end of the split
             offset_start = [0] * 4
+            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
             for idx, out in enumerate(outputs):
-                split_op.ofm_shapes[idx] = Shape4D(out.shape)
-                if out == tens:
-                    break
-                if axis >= 0:
-                    axis_4D = axis + (4 - len(out.shape))
+                if axis_4D_list is not None:
+                    axis_4D = axis_4D_list[idx]
                 else:
-                    axis_4D = axis
+                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
+                    if axis >= 0:
+                        axis_4D = axis + (4 - len(out.shape))
+                    else:
+                        axis_4D = axis
+
+                if out == tens:
+                    ofm_shape_idx = idx
+                    break
 
                 offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)
 
@@ -145,7 +168,7 @@
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
         new_op.ifm_shapes.append(Shape4D(inp.shape))
-        new_op.ofm_shapes.append(Shape4D(full_shape(4, tens.shape, 1)))
+        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx].clone())
         DebugDatabase.add_optimised(split_op, new_op)
 
     return tens
@@ -158,9 +181,9 @@
     return total_padding
 
 
-def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims, explicit_padding):
-    ypad = needed_total_padding(int(input_dims[1]), int(stride[1]), int(kernel_size[0]))
-    xpad = needed_total_padding(int(input_dims[2]), int(stride[2]), int(kernel_size[1]))
+def calc_padding_and_skirt(padding_type, kernel_size, stride, input_shape, explicit_padding):
+    ypad = needed_total_padding(int(input_shape.height), int(stride[1]), int(kernel_size[0]))
+    xpad = needed_total_padding(int(input_shape.width), int(stride[2]), int(kernel_size[1]))
     if padding_type == Padding.SAME:
         left_pad = (xpad + 0) // 2
         right_pad = (xpad + 1) // 2
@@ -184,11 +207,11 @@
     return padding, skirt
 
 
-def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dims, upscaling_factor):
+def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
     kernel_height, kernel_width = kernel_size[0], kernel_size[1]
     if padding_type == Padding.SAME:
-        ypad = needed_total_padding(int(input_dims[1]) * upscaling_factor, int(stride[1]), int(kernel_height))
-        xpad = needed_total_padding(int(input_dims[2]) * upscaling_factor, int(stride[2]), int(kernel_width))
+        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
+        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
         right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
         bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
         left_pad = max(kernel_width - 1 - right_pad, 0)
@@ -225,7 +248,7 @@
     op.name = op.name + "_add"
     op.attrs["resizebilinear"] = True
     # Create an input tensor filled with zeros
-    shape = op.outputs[0].shape
+    shape = op.ofm_shapes[0].as_list()
     tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
     tens.values = np.zeros(shape)
     tens.quant_values = np.zeros(shape, np.uint8)
@@ -258,8 +281,8 @@
         op.attrs["padding"] = Padding.SAME
     op.inputs[0].resampling_mode = resampling_mode.NEAREST
 
-    upscaled_shape = np.array(op.inputs[0].shape[1:3])
-    out_shape = np.array(op.outputs[0].shape[1:3])
+    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())  # ndarray: used element-wise with .all() below
+    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
     if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
         return op
 
@@ -276,8 +299,8 @@
             scaled_op.outputs = outputs
             scaled_op.outputs[0].ops = [scaled_op]
         else:
-            shape = outputs[0].shape.copy()
-            shape[1:3] = upscaled_shape[0:2]
+            shape = op.ofm_shapes[0].as_list()
+            shape[1:3] = upscaled_shape
             out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
             out_tens.quantization = op.outputs[0].quantization.clone()
             out_tens.quantization.quant_min = np.iinfo(np.int16).min
@@ -300,11 +323,11 @@
 
 def fixup_resizebilinear(op, arch, nng):
     if op.type == Op.ResizeBilinear and op.run_on_npu:
-        if op.inputs[0].shape == op.outputs[0].shape:
+        if op.ifm_shapes[0] == op.ofm_shapes[0]:
             # Bypass nop resizebilinear
             op.inputs = op.inputs[:1]
             op.type = Op.Identity
-        elif op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
+        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
             convert_resizebilinear_1x1_to_add(op)
         else:
             convert_resizebilinear_to_2x2_pool(op)
@@ -321,109 +344,26 @@
     return op
 
 
-def fixup_fully_connected_input(op, arch, nng):
-    if op.type == Op.FullyConnected:
-        inp = op.inputs[0]
-        weights = op.inputs[1]
-
-        n_in_elems = weights.shape[-2]
-        elms = inp.elements()
-        batch_size = elms // n_in_elems
-        assert batch_size * n_in_elems == elms
-
-        desired_shape = [batch_size, n_in_elems]
-        if inp.shape != desired_shape:
-            # mismatch, insert a reshape to fix this.
-            op.set_input_tensor(create_reshape_tensor(inp, desired_shape), 0)
-
-    return op
-
-
 def convert_batched_fc_shape(op, arch, nng):
     if op.type == Op.FullyConnected:
-        ifm = op.inputs[0]
-        ofm = op.outputs[0]
-        # Check if the FC is 2D and first dimension indicates batching
-        # TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete
-        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1:
-            n = ifm.shape[0]
+        # Check if the first dimension indicates batching
+        if op.ifm_shapes[0].batch > 1:
             batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+            n = op.ifm_shapes[0].batch
             h, w = batching_split.get(n, (1, n))
+            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
 
-            prev_op = ifm.ops[0]
-            desired_shape = [1, h, w, ifm.shape[-1]]
-            op.ifm_shapes[0] = Shape4D(desired_shape)
-
-            if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
-                # There is a preceding Reshape
-                # Compare input of prev_op and input of op, to see if prev_op can be removed
-                ifm_prev_op = prev_op.inputs[0]
-                if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm):
-                    # prev_op can be removed
-                    op.set_input_tensor(ifm_prev_op, 0)
-                else:
-                    op.inputs[0].set_all_shapes(desired_shape)
-                    prev_op.set_input_tensor(
-                        create_const_tensor(prev_op.inputs[1].name, [1], DataType.int32, desired_shape), 1
-                    )
-                    prev_op.attrs["new_shape"] = desired_shape
-            else:
-                # Add reshape op to the input if there is no preceding reshape
-                ifm.consumer_list.remove(op)
-                op.set_input_tensor(create_reshape_tensor(ifm, desired_shape), 0)
+            op.ifm.avoid_NHCWB16 = True
 
             # Reshape Weights to be 4D. IO becomes HWIO
             weight_tensor = op.inputs[1]
             weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
             weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
 
-            desired_shape = [1, h, w, ofm.shape[-1]]
-            op.ofm_shapes[0] = Shape4D(desired_shape)
-
-            if (
-                len(ofm.consumer_list) == 1
-                and ofm.consumer_list[0] is not None
-                and ofm.consumer_list[0].type == Op.Reshape
-            ):
-                # There is a subsequent Reshape
-                # Compare desired shape and output of consumer op, to see if consumer op can be removed
-                ofm_cons_op = ofm.consumer_list[0].outputs[0]
-                if desired_shape == ofm_cons_op.shape and check_quantized_tens_scaling_equal(ofm, ofm_cons_op):
-                    op.outputs[0] = ofm_cons_op
-                    op.outputs[0].ops = [op]
-                else:
-                    op.outputs[0].set_all_shapes(desired_shape)
-            else:
-                # Add reshape op to the output
-                op.set_output_tensor(create_reshape_tensor(ofm, desired_shape, False))
-    return op
-
-
-def fixup_pack_input(op, arch, nng):
-    if op.type == Op.Pack:
-        # Pack is also referred to as Stack
-        # Requires the rewrite_concat function to be called on the op afterwards
-        axis = int(op.attrs["axis"])
-        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
-
-        # Construct 1 shape tensor to be used by all inserted reshape ops
-        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, desired_shape)
-
-        for idx, inp in enumerate(op.inputs):
-            reshape_out = inp.clone("_reshaped")
-            reshape_out.set_all_shapes(desired_shape)
-
-            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
-            reshape_op.attrs["new_shape"] = desired_shape
-            reshape_op.inputs = [inp, new_shape_tens]
-            reshape_op.set_output_tensor(reshape_out)
-            reshape_op.set_ifm_ofm_shapes()
-            DebugDatabase.add_optimised(op, reshape_op)
-
-            op.inputs[idx] = reshape_out
-
-        op.type = Op.PackReshaped
-
+            n = op.ofm_shapes[0].batch
+            h, w = batching_split.get(n, (1, n))
+            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
+            op.ofm.avoid_NHCWB16 = True
     return op
 
 
@@ -441,12 +381,19 @@
     return op
 
 
-def fixup_stridedslice_output(tens, arch, nng):
-    op = tens.ops[0]
-    if op.run_on_npu and op.type == Op.StridedSlice:
-        reshape_input_shape = tens.shape
-        new_axis_mask = op.attrs["new_axis_mask"]
-        shrink_axis_mask = op.attrs["shrink_axis_mask"]
+def rewrite_stridedslice_output(op, arch, nng):
+    if not op.run_on_npu or op.type != Op.StridedSlice:
+        return op
+
+    new_axis_mask = op.attrs["new_axis_mask"]
+    shrink_axis_mask = op.attrs["shrink_axis_mask"]
+
+    if shrink_axis_mask == 0 and new_axis_mask == 0:
+        return op
+
+    axis_4D = [0] * len(op.outputs)
+    for idx, out_tens in enumerate(op.outputs):
+        output_shape = list(out_tens.shape)
 
         if shrink_axis_mask != 0:
             n = 0
@@ -456,10 +403,16 @@
                 n += 1
                 shrink_axis_mask &= shrink_axis_mask - 1
                 axis = int(math.log2(prev_mask - shrink_axis_mask))
-                reshape_input_shape = reshape_input_shape[:axis] + [1] + reshape_input_shape[axis:]
+                output_shape = output_shape[:axis] + [1] + output_shape[axis:]
 
-            assert len(tens.shape) == (len(op.inputs[0].shape) - n)
+            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
             op.attrs["shrink_axis_mask"] = 0
+            if axis >= 0:
+                axis_4D[idx] = axis + (4 - len(output_shape))
+            else:
+                axis_4D[idx] = axis
+            op.ofm_shapes[idx] = Shape4D(output_shape)
+
         elif new_axis_mask != 0:
             n = 0
             axis = 0
@@ -468,77 +421,62 @@
                 n += 1
                 new_axis_mask &= new_axis_mask - 1
                 axis = int(math.log2(prev_mask - new_axis_mask))
-                reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :]
+                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                 new_axis_mask >>= 1
 
-            assert len(tens.shape) == (len(op.inputs[0].shape) + n)
+            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
             op.attrs["new_axis_mask"] = 0
-        else:
-            # Equal Rank StridedSlice, no need to insert reshape
-            return tens
+            if axis >= 0:
+                axis_4D[idx] = axis + (4 - len(output_shape))
+            else:
+                axis_4D[idx] = axis
+            op.ofm_shapes[idx] = Shape4D(output_shape)
 
-        # Construct 1 shape tensor to be used by all inserted reshape ops
-        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
+        if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
+            out_tens.avoid_NHCWB16 = True
 
-        for idx, out_tens in enumerate(op.outputs):
-            op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape)
-            reshape_in = out_tens.clone("_reshaped")
-            reshape_in.set_all_shapes(reshape_input_shape)
-            reshape_in.ops = [op]
-
-            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
-            reshape_op.attrs["new_shape"] = reshape_input_shape
-            reshape_op.inputs = [reshape_in, new_shape_tens]
-            reshape_op.set_output_tensor(out_tens)
-            reshape_op.set_ifm_ofm_shapes()
-
-            op.outputs[idx] = reshape_in
-
-    return tens
+    op.attrs["split_axis_4D"] = axis_4D
+    return op
 
 
-def fixup_unpack_output(tens, arch, nng):
-    op = tens.ops[0]
+def rewrite_unpack_output(op, arch, nng):
+    tens = op.outputs[0]
     if op.run_on_npu and op.type == Op.Unpack:
         # Unpack is also referred to as Unstack
-        # Requires the rewrite_split function to be called on the op afterwards
         axis = int(op.attrs["axis"])
         op.type = Op.UnpackReshaped
-        reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
+        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
 
-        # Construct 1 shape tensor to be used by all inserted reshape ops
-        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
+        if axis >= 0:
+            axis_4D = axis + (4 - len(desired_output_shape))
+        else:
+            axis_4D = axis
 
+        axis_4D_list = [0] * len(op.outputs)
         for idx, out_tens in enumerate(op.outputs):
-            reshape_in = out_tens.clone("_reshaped")
-            reshape_in.set_all_shapes(reshape_input_shape)
-            reshape_in.ops = [op]
+            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
+            axis_4D_list[idx] = axis_4D
+            if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
+                out_tens.avoid_NHCWB16 = True
 
-            reshape_op = Operation(Op.Reshape, "{}{}_reshape".format(op.name, idx))
-            reshape_op.attrs["new_shape"] = reshape_input_shape
-            reshape_op.inputs = [reshape_in, new_shape_tens]
-            reshape_op.set_output_tensor(out_tens)
-            reshape_op.set_ifm_ofm_shapes()
-            DebugDatabase.add_optimised(op, reshape_op)
-
-            op.outputs[idx] = reshape_in
-    return tens
+        op.attrs["split_axis_4D"] = axis_4D_list
+    return op
 
 
 def add_padding_fields(op, arch, nng):
     if op.run_on_npu:
         if "padding" in op.attrs:
+            input_shape = op.ifm_shapes[0]
+            output_shape = op.ofm_shapes[0]
             if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                 kernel_size = op.inputs[1].shape[:2]
-                input_shape = op.inputs[0].shape
             elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                 kernel_size = op.attrs["ksize"][1:3]
-                input_shape = op.inputs[0].shape
             else:
                 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
 
             if op.type == Op.Conv2DBackpropInputSwitchedBias:
-                upscaling_factor = op.outputs[0].shape[1] // input_shape[1]
+                upscaling_factor = output_shape.height // input_shape.height
                 padding, skirt = calc_upscaled_padding_and_skirt(
                     op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                 )
@@ -582,10 +520,10 @@
     # switch of the operator type (and weight order)
 
     if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
-        ifm_tensor = op.inputs[0]
+        ifm_shape = op.ifm_shapes[0]
         weight_tensor = op.inputs[1]
-        ofm_tensor = op.outputs[0]
-        if (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"]):
+        ofm_shape = op.ofm_shapes[0]
+        if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
             # Change op type to Conv2d
             op.type = Op.Conv2DBias
             del op.attrs["channel_multiplier"]
@@ -596,7 +534,7 @@
         else:
             raise UnsupportedFeatureError(
                 f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
-                f" ifm channels = {ifm_tensor.shape[3]}, ofm channels = {ofm_tensor.shape[3]}",
+                f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
             )
         DebugDatabase.add_optimised(op, op)
     return op
@@ -620,17 +558,15 @@
         op.type == Op.Conv2DBias
         and op.op_index == 0
         and stride_x == 2
-        and len(ifm_tensor.shape) == 4
-        and ifm_tensor.shape[3] <= 4
-        and ifm_tensor.shape[2] % 2 == 0
+        and op.ifm_shapes[0].depth <= 4
+        and op.ifm_shapes[0].width % 2 == 0
         and weight_tensor is not None
         and weight_tensor.shape[1] >= 2
     ):
+        ifm_shape = op.ifm_shapes[0]
         # IFM
-        ifm_reshaped = create_reshape_tensor(
-            ifm_tensor, [ifm_tensor.shape[0], ifm_tensor.shape[1], ifm_tensor.shape[2] // 2, ifm_tensor.shape[3] * 2]
-        )
-        op.set_input_tensor(ifm_reshaped, 0)
+        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
+        op.ifm.avoid_NHCWB16 = True
 
         # Weights
         weight_shape = weight_tensor.shape
@@ -657,8 +593,6 @@
         stride_x = 1
         op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
 
-        op.set_ifm_ofm_shapes()
-
     return op
 
 
@@ -683,27 +617,6 @@
             weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
             weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
 
-            # The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it
-            # back to 4D afterwards as the next layer is expecting that shape
-            orig_ofm_tensor = op.outputs[0]
-            # Reshape this ops output to be 2D: {(N*H*W), C} (We know N H and W are all 1 so this becomes {1, C})
-            fc_ofm_tensor = orig_ofm_tensor.clone("_fc")
-            fc_ofm_tensor.set_all_shapes([1, fc_ofm_tensor.shape[-1]])
-            fc_ofm_tensor.ops = [op]
-            # Add a reshape after the new OFM to convert it back to the original 4D shape
-            reshape_name = op.name + "_reshape"
-            new_shape_tens = create_const_tensor(reshape_name + "_shape", [1], DataType.int32, orig_ofm_tensor.shape)
-            reshape_op = Operation(Op.Reshape, reshape_name)
-            reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
-            reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
-            reshape_op.set_output_tensor(orig_ofm_tensor)
-            reshape_op.set_ifm_ofm_shapes()
-
-            # Replace this ops OFM to point to the 2D tensor
-            op.outputs[0] = fc_ofm_tensor
-            op.set_ifm_ofm_shapes()
-            # Record optimisation in debug database
-            DebugDatabase.add_optimised(op, reshape_op)
             DebugDatabase.add_optimised(op, op)
     return op
 
@@ -722,14 +635,6 @@
             # Tidy up and assign the ifm and ofm to the new op
             ifm.consumer_list.remove(op)
 
-            # if not 4d, reshape ifm/ofm
-            if len(ifm.shape) < 4:
-                ifm_shaped = create_reshape_tensor(ifm, full_shape(4, ifm.shape, 1))
-                ifm = ifm_shaped
-            if len(ofm.shape) < 4:
-                ofm_shaped = create_reshape_tensor(ofm, full_shape(4, ofm.shape, 1), False)
-                ofm = ofm_shaped
-
             relu_fused_op.add_input_tensor(ifm)
             relu_fused_op.set_output_tensor(ofm)
             relu_fused_op.set_ifm_ofm_shapes()
@@ -737,6 +642,7 @@
     return op
 
 
+# TODO remove if mem only ops can all be removed
 # Reorder activation op if it's after the memory only operations
 def fixup_act_reorder(op, arch, nng):
     if op.type.is_relu_op() or op.type in (Op.Sigmoid, Op.Tanh):
@@ -752,8 +658,8 @@
             act_op_out = act_op.inputs[0].clone("_acted")
             act_op_out.quantization = op.outputs[0].quantization.clone()
             act_op.set_output_tensor(act_op_out)
-            act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape)
-            act_op.ofm_shapes[0] = Shape4D(act_op_out.shape)
+            act_op.ofm_shapes[0] = act_op.ifm_shapes[0].clone()
+            act_op.ifm_shapes[0] = prep_op.ifm_shapes[0].clone()
 
             # Update the consumer list
             act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -1078,39 +984,94 @@
     return op
 
 
-def remove_unwanted_reshapes(op, arch, nng):
-    # Try to remove reshapes enclosing ElementWise operator with only one non-constant input
-    if not op.run_on_npu or not op.type.is_elementwise_op():
-        return op
+def remove_reshapes(op, arch):
+    if op.run_on_npu and op.type == Op.Reshape:
+        ofm = op.ofm
+        ifm = op.ifm
 
-    # Check if the ElementWise operator only have one non-constant input
-    non_const_tens = [x for x in op.inputs if x.ops[0].type != Op.Const]
-    if len(non_const_tens) != 1:
-        return op
-    ifm = non_const_tens[0]
+        # Check if quantization is the same in the input and output for the reshape ops
+        if not check_quantized_tens_scaling_equal(ifm, ofm):
+            # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
+            # In order to remove this reshape either quantization properties need to be moved to Operator,
+            # or the reshape need to be replace with a NOP.
+            return
 
-    # Check if operation is enclosed by Reshapes that can be removed
-    ofm = op.outputs[0]
-    prev_op = ifm.ops[0]
-    if (
-        len(ifm.consumer_list) == 1
-        and prev_op.type == Op.Reshape
-        and len(ofm.consumer_list) == 1
-        and ofm.consumer_list[0].type == Op.Reshape
-    ):
-        # Operation is enclosed by reshapes, check if they can be removed
-        prev_op_ifm, prev_op_ofm = prev_op.get_ifm_ofm()
-        cons_op = ofm.consumer_list[0]
-        cons_op_ifm = ofm
-        cons_op_ofm = cons_op.outputs[0]
-        if len(prev_op_ifm.shape) == len(cons_op_ofm.shape):
-            # Check if quantization is the same in the input and output for the reshape ops
-            if check_quantized_tens_scaling_equal(prev_op_ifm, prev_op_ofm) and check_quantized_tens_scaling_equal(
-                cons_op_ifm, cons_op_ofm
-            ):
-                op.set_input_tensor(prev_op_ifm, 0)
-                op.set_output_tensor(cons_op_ofm)
-    return op
+        # Check if ifm is a sg input
+        if ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
+            # put the reshape on CPU
+            op.run_on_npu = False
+            return
+
+        # Check if Reshape ifm/ofm are network ifm/ofm
+        ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
+        ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
+
+        if ifm_is_sg_ofm and ofm_is_sg_ofm:
+            # Both ifm and ofm are sg outputs,add reshape to the ifm and put it on CPU
+            ifm_cons_list_copy = ifm.consumer_list.copy()
+            ifm_ops_copy = ifm.ops.copy()
+            for ifm_cons in ifm_cons_list_copy:
+                if ifm_cons is None:
+                    # Create a reshape op with ifm as output
+                    name = ifm.name + "_cpu_reshape"
+                    reshape_ifm = ifm.clone()
+                    reshape_op = Operation(Op.Reshape, name)
+                    reshape_op.attrs["new_shape"] = ifm.shape
+                    reshape_op.add_input_tensor(reshape_ifm)
+                    reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, ifm.shape))
+                    reshape_op.set_output_tensor(ifm)
+                    reshape_op.set_ifm_ofm_shapes()
+                    reshape_op.run_on_npu = False
+                    reshape_op.ofm.ops = [reshape_op]
+                    reshape_op.ofm.consumer_list = [None]
+
+                    # Set reshape_ifm producers
+                    for prev_op in ifm_ops_copy:
+                        prev_op.outputs = [reshape_ifm]
+                        reshape_ifm.ops.append(prev_op)
+
+                    # Set reshape_ifm consumers
+                    for ifm_cons in ifm_cons_list_copy:
+                        if ifm_cons is not None:
+                            for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+                                if cons_ifm == ifm:
+                                    ifm_cons.set_input_tensor(reshape_ifm, ifm_idx)
+
+                    ifm = reshape_ifm
+                    break
+            ifm_is_sg_ofm = False
+
+        if ofm_is_sg_ofm:
+            # Bypassed by replacing ifm with ofm
+            ofm.ops = []
+            for prev_op in ifm.ops:
+                prev_op.outputs = [ofm]
+                ofm.ops.append(prev_op)
+
+            # All ifm consumers need to use ofm as input
+            for ifm_cons in ifm.consumer_list:
+                for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
+                    if cons_ifm == ifm:
+                        ifm_cons.set_input_tensor(ofm, ifm_idx)
+            if op.ifm_shapes[0] != op.ofm_shapes[0]:
+                ofm.avoid_NHCWB16 = True
+        else:
+            # Bypassed Reshape by replacing ofm with ifm
+            for cons in ofm.consumer_list:
+                for ifm_idx, cons_ifm in enumerate(cons.inputs):
+                    if cons_ifm == ofm:
+                        cons.set_input_tensor(ifm, ifm_idx)
+            if op.ifm_shapes[0] != op.ofm_shapes[0]:
+                ifm.avoid_NHCWB16 = True
+
+
+def check_reshapes(op, arch):
+    if op.run_on_npu and op.type == Op.Reshape:
+        ofm = op.ofm
+
+        if check_quantized_tens_scaling_equal(op.ifm, ofm):
+            # Reshape should have been removed
+            raise VelaError(f"Reshape op {op} expected to have been removed, still remains")
 
 
 def fuse_activation_function_with_prev(op, arch, nng):
@@ -1174,13 +1135,19 @@
 def add_attrs_to_resizebilinear(op, arch, nng):
     if op.type == Op.ResizeBilinear and op.run_on_npu:
         input_tensor = op.inputs[0]
-        upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2]
-        out_shape = op.outputs[0].shape[1:3]
-        if not op.attrs["align_corners"] and out_shape == upscaled_shape:
+        input_shape = op.ifm_shapes[0]
+        upscaled_height = input_shape.height * 2
+        upscaled_width = input_shape.width * 2
+        out_shape = op.ofm_shapes[0]
+        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
             # this means the output is supposed to be a x2 upscale,
             # so we need to do SAME padding
             op.attrs["padding"] = Padding.SAME
-        elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
+        elif (
+            op.attrs["align_corners"]
+            and out_shape.height == (upscaled_height - 1)
+            and out_shape.width == (upscaled_width - 1)
+        ):
             # here we can just run the avg pool without padding and
             # produce a (M * 2 - 1, N * 2 - 1) sized output
             op.attrs["padding"] = Padding.VALID
@@ -1229,26 +1196,52 @@
             nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
         )
 
+    # Handle Concat Ops
+    for idx, sg in enumerate(nng.subgraphs):
+        # rewrite graph pass
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [rewrite_concat_ops], rewrite_unsupported=False,
+        )
+
+    # Handle Split Ops
+    for idx, sg in enumerate(nng.subgraphs):
+        # rewrite graph pass
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng,
+            sg,
+            arch,
+            [],
+            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
+            rewrite_unsupported=False,
+        )
+
+    for idx, sg in enumerate(nng.subgraphs):
+        # rewrite graph pass
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
+        )
+
+    # Removal of reshapes
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+        sg.refresh_after_modification()
+
     op_rewrite_list = [
         set_tensor_equivalence,
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,
         optimise_strided_conv,
-        fixup_fully_connected_input,
         convert_batched_fc_shape,
-        fixup_pack_input,
         unfuse_activation_function,
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
         fixup_act_reorder,
-        fixup_elementwise_with_scalars,
+        fixup_elementwise_with_scalars,  # TODO Move to early stage?
         reorder_depthwise_weights,
         fixup_resizebilinear,
         fixup_bias_tensors,
-        convert_nop_split_to_identity,
         convert_mul_max_to_abs_or_lrelu,
-        remove_unwanted_reshapes,
         convert_lrelu,
         convert_tanh_sigmoid_to_lut,
     ]
@@ -1269,24 +1262,9 @@
             [fuse_activation_function_with_prev, optimise_pad, add_padding_fields],
         )
 
-    # Post-optimisation operator debug tracing
+    # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
     for sg in nng.subgraphs:
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [_record_optimised])
-
-    if verbose_graph:
-        nng.print_graph()
-    return nng
-
-
-def optimise_graph_b(nng, arch, verbose_graph=False):
-    if verbose_graph:
-        nng.print_graph()
-
-    for idx, sg in enumerate(nng.subgraphs):
-        # combined rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [],
-        )
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised])
 
     if verbose_graph:
         nng.print_graph()