MLBEDSW-3654 Add/use op ifm/ofm shapes

Add ifm/ofm shapes to op
Change code to rely on these shapes

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 4806001..fdb0fae 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -75,7 +75,7 @@
             new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
             new_op.inputs = [inp]
             new_op.outputs = [tens]
-            new_op.attrs["concat_axis"] = axis
+            new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape))
             new_op.attrs["concat_start"] = offset
             offset += inp.shape[axis]
             new_op.attrs["concat_end"] = offset
@@ -116,21 +116,20 @@
         # be calculated from the index of the output tensor
         if axis is not None:
             # Get the start and end of the split
-            offset_start = [0] * len(tens.shape)
-            offset_end = [0] * len(tens.shape)
-            for out in outputs:
+            offset_start = [0] * 4
+            for idx, out in enumerate(outputs):
                 if out == tens:
                     break
-                offset_start[axis] += out.shape[axis]
+                axis_4D = axis + (4 - len(out.shape))
+                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
 
                 # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
                 if (offset_start[-1] % 16) != 0:
                     inp.avoid_NHCWB16 = True
-
-            offset_end[axis] = offset_start[axis] + tens.shape[axis]
+        else:
+            offset_start = full_shape(4, offset_start, 0)
 
         new_op.attrs["split_start"] = offset_start
-        new_op.attrs["split_end"] = offset_end
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
         DebugDatabase.add_optimised(split_op, new_op)
@@ -217,6 +216,8 @@
     # Set the add inputs
     op.inputs[1] = op.inputs[0]
     op.inputs[0] = tens
+    op.ifm_shapes = []
+    op.ofm_shapes = []
 
     return op
 
@@ -321,13 +322,16 @@
         ifm = op.inputs[0]
         ofm = op.outputs[0]
         # Check if the FC is 2D and first dimension indicates batching
-        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1:
+        # TODO: op.ifm_shapes[0][0] > 1 is enough when refactoring is complete
+        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1:
             n = ifm.shape[0]
             batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
             h, w = batching_split.get(n, (1, n))
 
             prev_op = ifm.ops[0]
             desired_shape = [1, h, w, ifm.shape[-1]]
+            op.ifm_shapes[0] = desired_shape
+
             if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
                 # There is a preceding Reshape
                 # Compare input of prev_op and input of op, to see if prev_op can be removed
@@ -352,6 +356,8 @@
             weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
 
             desired_shape = [1, h, w, ofm.shape[-1]]
+            op.ofm_shapes[0] = desired_shape
+
             if (
                 len(ofm.consumer_list) == 1
                 and ofm.consumer_list[0] is not None
@@ -451,6 +457,7 @@
         new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
 
         for idx, out_tens in enumerate(op.outputs):
+            op.ofm_shapes[idx] = new_shape_tens
             reshape_in = out_tens.clone("_reshaped")
             reshape_in.set_all_shapes(reshape_input_shape)
             reshape_in.ops = [op]
@@ -489,7 +496,6 @@
             DebugDatabase.add_optimised(op, reshape_op)
 
             op.outputs[idx] = reshape_in
-
     return tens
 
 
@@ -582,7 +588,7 @@
     # caching/double buffering for the weights.
     # (Weights dont need to be reloaded for convs when IFM H and W are 1)
     if op.type == Op.Conv2DBias:
-        _, h, w, _ = op.inputs[0].shape
+        _, h, w, _ = op.ifm_shapes[0]
         kh, kw, _, _ = op.inputs[1].shape
         if h == 1 and w == 1 and kh == 1 and kw == 1:
             # Overwrite this op as a Fully Connected Op
@@ -595,6 +601,7 @@
             weight_tensor = op.inputs[1]
             weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
             weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+
             # The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it
             # back to 4D afterwards as the next layer is expecting that shape
             orig_ofm_tensor = op.outputs[0]
@@ -609,6 +616,7 @@
             reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
             reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
             reshape_op.set_output_tensor(orig_ofm_tensor)
+
             # Replace this ops OFM to point to the 2D tensor
             op.outputs[0] = fc_ofm_tensor
             # Record optimisation in debug database
@@ -651,6 +659,8 @@
         prep_op = get_prepend_op(op)
         if prep_op is not None:
             act_op = op.clone("_reordered")
+            act_op.ifm_shapes = list(op.ifm_shapes)
+            act_op.ofm_shapes = list(op.ofm_shapes)
 
             # There is only one input tensor, overwrite it
             act_op.set_input_tensor(prep_op.inputs[0], 0)
@@ -658,6 +668,8 @@
             act_op_out = act_op.inputs[0].clone("_acted")
             act_op_out.quantization = op.outputs[0].quantization.clone()
             act_op.set_output_tensor(act_op_out)
+            act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1)
+            act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1)
 
             # Update the consumer list
             act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -704,6 +716,15 @@
     return op
 
 
+def set_ifm_ofm_op_shapes(op, arch, nng):
+    # Fill in op.ifm_shapes/op.ofm_shapes for NPU ops whose type needs them
+    if not (op.run_on_npu and op.type.needs_shapes()):
+        return op
+    if not (op.ifm_shapes or op.ofm_shapes):  # keep shapes set earlier
+        op.set_ifm_ofm_shapes()
+    return op
+
+
 def convert_softmax(op, arch, nng):
     if op.type == Op.Softmax and op.run_on_npu:
         softmax = SoftMax(op)
@@ -839,7 +860,7 @@
         mul_identity.add_input_tensor(identity_tens)
         fm_id = ofm.clone(op.name + "_id")
         mul_identity.set_output_tensor(fm_id)
-        DebugDatabase.add_optimised(op, mul_alpha)
+        DebugDatabase.add_optimised(op, mul_identity)
 
     # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
     op.type = Op.Maximum
@@ -869,6 +890,8 @@
     quantization.zero_point = 0
     tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
     op.add_input_tensor(tens)
+    op.ifm_shapes.append(full_shape(4, tens.shape, 1))
+
     # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
     # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
     # should be the same as the IFM
@@ -1072,10 +1095,20 @@
     if verbose_graph:
         nng.print_graph()
 
+    pre_process_list = [
+        supported_operator_check,
+        set_ifm_ofm_op_shapes,
+        # TODO: memory-only Op removal
+    ]
+
+    for idx, sg in enumerate(nng.subgraphs):
+        # rewrite graph pass
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+        )
+
     op_rewrite_list = [
         set_tensor_equivalence,
-        supported_operator_check,
-        # then do any rewrites of supported operators
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,
@@ -1106,7 +1139,7 @@
     for idx, sg in enumerate(nng.subgraphs):
         # remove passthrough tensors and attempt further optimizations
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
+            nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields],
         )
 
     # Post-optimisation operator debug tracing
@@ -1125,7 +1158,11 @@
     for idx, sg in enumerate(nng.subgraphs):
         # combined rewrite graph pass
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], []
+            nng,
+            sg,
+            arch,
+            [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split],
+            [set_ifm_ofm_op_shapes],
         )
 
     if verbose_graph: