Revert "Revert "MLBEDSW-3645 4D class for op ifm/ofm shapes""

This reverts commit df0a5905177f3a1b836076bc3f9f39b2e86f1794.

Reason for revert: <INSERT REASONING HERE>

Change-Id: I891c66fb29db9d25e942947e8d1c29a10610de51
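
Note: the hunks below swap raw shape lists (some built with full_shape()) for Shape4D
objects on op.ifm_shapes/op.ofm_shapes, and call op.set_ifm_ofm_shapes() wherever an
op's input or output tensors are rewired. The sketch below only illustrates the
interface the diff relies on (construction from a shape list, named
batch/height/width/depth accessors, get_dim()); the real class lives in
ethosu/vela/shape4d.py and is not part of this diff:

    # Illustrative sketch only -- not the actual ethosu/vela/shape4d.py.
    class Shape4D:
        def __init__(self, shape):
            # Pad shapes shorter than 4 dimensions with leading 1s (NHWC order).
            assert len(shape) <= 4
            n, h, w, c = [1] * (4 - len(shape)) + list(shape)
            self.batch, self.height, self.width, self.depth = n, h, w, c

        def get_dim(self, dim):
            # dim is a 4D axis index: 0 = N, 1 = H, 2 = W, 3 = C
            return (self.batch, self.height, self.width, self.depth)[dim]

        def as_list(self):
            return [self.batch, self.height, self.width, self.depth]

With an interface like that, expressions such as op.ifm_shapes[0].batch and
split_op.ofm_shapes[idx].get_dim(axis_4D) in the hunks below read directly, where the
previous code indexed plain 4-element lists.
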
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index fdb0fae..1128a31 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -37,6 +37,7 @@
 from .operation import Operation
 from .operation import Padding
 from .operation_util import create_avgpool_nop
+from .shape4d import Shape4D
 from .softmax import SoftMax
 from .tensor import check_quantized_tens_scaling_equal
 from .tensor import create_const_tensor
@@ -82,6 +83,7 @@
             new_op.run_on_npu = True
             tens.ops.append(new_op)
             DebugDatabase.add_optimised(concat_op, new_op)
+            new_op.set_ifm_ofm_shapes()
         assert tens.shape[axis] == offset
 
         # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -121,7 +123,8 @@
                 if out == tens:
                     break
                 axis_4D = axis + (4 - len(out.shape))
-                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
+
+                offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)
 
                 # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 needs to be avoided in the input
                 if (offset_start[-1] % 16) != 0:
@@ -132,6 +135,7 @@
         new_op.attrs["split_start"] = offset_start
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
+        new_op.set_ifm_ofm_shapes()
         DebugDatabase.add_optimised(split_op, new_op)
 
     return tens
@@ -189,6 +193,7 @@
     if op.type == Op.Conv2DBackpropInput:
         # flip the inputs
         op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
+        op.set_ifm_ofm_shapes()
         op.type = Op.Conv2DBackpropInputSwitchedBias
 
         # Update strides
@@ -216,8 +221,7 @@
     # Set the add inputs
     op.inputs[1] = op.inputs[0]
     op.inputs[0] = tens
-    op.ifm_shapes = []
-    op.ofm_shapes = []
+    op.set_ifm_ofm_shapes()
 
     return op
 
@@ -323,14 +327,14 @@
         ofm = op.outputs[0]
         # Check if the FC is 2D and first dimension indicates batching
         # TODO op.ifm_shape[0] > 1 is enough when the refactoring is complete
-        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1:
+        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1:
             n = ifm.shape[0]
             batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
             h, w = batching_split.get(n, (1, n))
 
             prev_op = ifm.ops[0]
             desired_shape = [1, h, w, ifm.shape[-1]]
-            op.ifm_shapes[0] = desired_shape
+            op.ifm_shapes[0] = Shape4D(desired_shape)
 
             if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
                 # There is a preceding Reshape
@@ -356,7 +360,7 @@
             weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
 
             desired_shape = [1, h, w, ofm.shape[-1]]
-            op.ofm_shapes[0] = desired_shape
+            op.ofm_shapes[0] = Shape4D(desired_shape)
 
             if (
                 len(ofm.consumer_list) == 1
@@ -395,6 +399,7 @@
             reshape_op.attrs["new_shape"] = desired_shape
             reshape_op.inputs = [inp, new_shape_tens]
             reshape_op.set_output_tensor(reshape_out)
+            reshape_op.set_ifm_ofm_shapes()
             DebugDatabase.add_optimised(op, reshape_op)
 
             op.inputs[idx] = reshape_out
@@ -413,6 +418,7 @@
         act_op.set_output_tensor(out_tens)
         act_op.add_input_tensor(intermediate_tens)
         op.set_output_tensor(intermediate_tens)
+        act_op.set_ifm_ofm_shapes()
 
     return op
 
@@ -457,7 +463,7 @@
         new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
 
         for idx, out_tens in enumerate(op.outputs):
-            op.ofm_shapes[idx] = new_shape_tens
+            op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape)
             reshape_in = out_tens.clone("_reshaped")
             reshape_in.set_all_shapes(reshape_input_shape)
             reshape_in.ops = [op]
@@ -466,6 +472,7 @@
             reshape_op.attrs["new_shape"] = reshape_input_shape
             reshape_op.inputs = [reshape_in, new_shape_tens]
             reshape_op.set_output_tensor(out_tens)
+            reshape_op.set_ifm_ofm_shapes()
 
             op.outputs[idx] = reshape_in
 
@@ -493,6 +500,7 @@
             reshape_op.attrs["new_shape"] = reshape_input_shape
             reshape_op.inputs = [reshape_in, new_shape_tens]
             reshape_op.set_output_tensor(out_tens)
+            reshape_op.set_ifm_ofm_shapes()
             DebugDatabase.add_optimised(op, reshape_op)
 
             op.outputs[idx] = reshape_in
@@ -588,7 +596,8 @@
     # caching/double buffering for the weights.
     # (Weights don't need to be reloaded for convs when IFM H and W are 1)
     if op.type == Op.Conv2DBias:
-        _, h, w, _ = op.ifm_shapes[0]
+        h = op.ifm_shapes[0].height
+        w = op.ifm_shapes[0].width
         kh, kw, _, _ = op.inputs[1].shape
         if h == 1 and w == 1 and kh == 1 and kw == 1:
             # Overwrite this op as a Fully Connected Op
@@ -616,9 +625,11 @@
             reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
             reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
             reshape_op.set_output_tensor(orig_ofm_tensor)
+            reshape_op.set_ifm_ofm_shapes()
 
             # Replace this op's OFM to point to the 2D tensor
             op.outputs[0] = fc_ofm_tensor
+            op.set_ifm_ofm_shapes()
             # Record optimisation in debug database
             DebugDatabase.add_optimised(op, reshape_op)
             DebugDatabase.add_optimised(op, op)
@@ -649,6 +660,7 @@
 
             relu_fused_op.add_input_tensor(ifm)
             relu_fused_op.set_output_tensor(ofm)
+            relu_fused_op.set_ifm_ofm_shapes()
             op = relu_fused_op
     return op
 
@@ -668,8 +680,8 @@
             act_op_out = act_op.inputs[0].clone("_acted")
             act_op_out.quantization = op.outputs[0].quantization.clone()
             act_op.set_output_tensor(act_op_out)
-            act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1)
-            act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1)
+            act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape)
+            act_op.ofm_shapes[0] = Shape4D(act_op_out.shape)
 
             # Update the consumer list
             act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -839,6 +851,7 @@
     mul_alpha.add_input_tensor(alpha_tens)
     fm_alpha = ofm.clone(op.name + "_alpha")
     mul_alpha.set_output_tensor(fm_alpha)
+    mul_alpha.set_ifm_ofm_shapes()
     DebugDatabase.add_optimised(op, mul_alpha)
 
     if check_quantized_tens_scaling_equal(ifm, ofm):
@@ -860,6 +873,7 @@
         mul_identity.add_input_tensor(identity_tens)
         fm_id = ofm.clone(op.name + "_id")
         mul_identity.set_output_tensor(fm_id)
+        mul_identity.set_ifm_ofm_shapes()
         DebugDatabase.add_optimised(op, mul_identity)
 
     # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
@@ -890,7 +904,7 @@
     quantization.zero_point = 0
     tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
     op.add_input_tensor(tens)
-    op.ifm_shapes.append(full_shape(4, tens.shape, 1))
+    op.ifm_shapes.append(Shape4D(tens.shape))
 
     # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
     # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
@@ -1158,11 +1172,7 @@
     for idx, sg in enumerate(nng.subgraphs):
         # combined rewrite graph pass
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng,
-            sg,
-            arch,
-            [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split],
-            [set_ifm_ofm_op_shapes],
+            nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [],
         )
 
     if verbose_graph:
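
Note: the repeated op.set_ifm_ofm_shapes() calls added above follow one pattern: once an
op's inputs and outputs are final, its cached 4D shapes are refreshed. The last hunk
drops the separate set_ifm_ofm_op_shapes rewrite from the combined pre-order pass,
consistent with the shapes now being set explicitly at each rewrite site. A minimal
sketch of the assumed behaviour (the real method is defined on the Operation class,
is not shown in this diff, and may differ, e.g. in which inputs count as IFMs):

    # Simplified illustration of the assumed behaviour, not the real method.
    # Uses the Shape4D interface sketched in the note above the diff.
    def set_ifm_ofm_shapes(self):
        self.ifm_shapes = [Shape4D(t.shape) for t in self.inputs if t is not None]
        self.ofm_shapes = [Shape4D(t.shape) for t in self.outputs if t is not None]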