MLBEDSW-1693 Convert batched FC to Conv

Added support for converting a batched FullyConnected operation into an
equivalent 1x1 Conv2D, with the batch laid out as an HxW feature map.
This enables choosing a suitable block size.
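
A minimal standalone sketch (not part of the patch; the shapes and numpy
usage are illustrative only) of why the rewrite preserves results: a
batched FC over n rows equals a 1x1 convolution over the same rows laid
out as an h x w plane, because each spatial position sees exactly one
input row.

    import numpy as np

    n, c_in, c_out = 8, 16, 32                 # hypothetical sizes
    h, w = {4: (2, 2), 8: (2, 4), 16: (4, 4)}.get(n, (1, n))

    ifm = np.random.rand(n, c_in)
    weights = np.random.rand(c_in, c_out)

    fc_out = ifm @ weights                     # batched FullyConnected

    # Lay the batch out as a 1 x h x w x c_in feature map and apply the
    # same weights as a 1x1 convolution (per-position matmul).
    conv_in = ifm.reshape(1, h, w, c_in)
    conv_out = np.einsum("nhwi,io->nhwo", conv_in, weights)

    assert np.allclose(fc_out, conv_out.reshape(n, c_out))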

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Idc49e4fb6d29c554f10a38ece7996a7b7795ffad
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index a57ac82..f6b03f6 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -329,11 +329,82 @@
         desired_shape = [batch_size, n_in_elems]
         if inp.shape != desired_shape:
             # mismatch, insert a reshape to fix this.
-            op.inputs[0] = create_reshape_tensor(inp, desired_shape)
+            op.set_input_tensor(create_reshape_tensor(inp, desired_shape), 0)
 
     return op
 
 
+def convert_batched_fc_to_conv(op, arch):
+    if op.type == "FullyConnectedAct":
+        ifm = op.inputs[0]
+        ofm = op.outputs[0]
+        # Check if the FC is 2D and first dimension indicates batching
+        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] != 1:
+            n = ifm.shape[0]
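+            # Choose a spatial h x w layout for the n batch rows; fall back to a
+            # single 1 x n row when n has no predefined split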
+            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+            h, w = batching_split.get(n, (1, n))
+
+            # Convert to convolution
+            op.name += "_conv"
+            op.type = "Conv2DBiasAct"
+            faf = op.attrs.get("fused_activation_function", None)
+            op.attrs = {
+                "dilation": (1, 1, 1, 1),
+                "dilation_h_factor": 1,
+                "dilation_w_factor": 1,
+                "fused_activation_function": faf,
+                "npu_block_type": NpuBlockType.ConvolutionMxN,
+                "padding": b"SAME",
+                "stride_h": 1,
+                "stride_w": 1,
+                "strides": (1, 1, 1, 1),
+            }
+
+            prev_op = ifm.ops[0]
+            desired_shape = [1, h, w, ifm.shape[-1]]
+            if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == "Reshape":
+                # There is a preceding Reshape
+                # Compare input of prev_op and input of op, to see if prev_op can be removed
+                ifm_prev_op = prev_op.inputs[0]
+                if ifm_prev_op.shape == ifm.shape and ifm_prev_op.quantization.is_scaling_equal(ifm.quantization):
+                    # prev_op can be removed
+                    op.set_input_tensor(ifm_prev_op, 0)
+                else:
+                    op.inputs[0].set_all_shapes(desired_shape)
+                    prev_op.set_input_tensor(
+                        create_const_tensor(prev_op.inputs[1].name, [1], DataType.int32, desired_shape), 1
+                    )
+                    prev_op.attrs["new_shape"] = desired_shape
+            else:
+                # Add reshape op to the input if there is no preceding reshape
+                ifm.consumer_list.remove(op)
+                op.set_input_tensor(create_reshape_tensor(ifm, desired_shape), 0)
+
+            # Reshape Weights to be 4D. IO becomes HWIO
+            weight_tensor = op.inputs[1]
+            weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
+            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+
+            desired_shape = [1, h, w, ofm.shape[-1]]
+            if (
+                len(ofm.consumer_list) == 1
+                and ofm.consumer_list[0] is not None
+                and ofm.consumer_list[0].type == "Reshape"
+            ):
+                # There is a subsequent Reshape
+                # Compare desired shape and output of consumer op, to see if consumer op can be removed
+                ofm_cons_op = ofm.consumer_list[0].outputs[0]
+                if desired_shape == ofm_cons_op.shape and ofm.quantization.is_scaling_equal(ofm_cons_op.quantization):
+                    op.outputs[0] = ofm_cons_op
+                    op.outputs[0].ops = [op]
+                else:
+                    op.outputs[0].set_all_shapes(desired_shape)
+            else:
+            # Add reshape op to the output
+                op.set_output_tensor(create_reshape_tensor(ofm, desired_shape, False))
+    return op
+
+
 def fixup_pack_input(op, arch):
     if op.type == "Pack":
         # Pack is also referred to as Stack
@@ -598,10 +669,18 @@
         prep_op = get_prepend_op(op)
         if prep_op is not None:
             act_op = op.clone("_reordered")
-            act_op.inputs = [prep_op.inputs[0]]
+
+            # There is only one input tensor, so overwrite it
+            act_op.set_input_tensor(prep_op.inputs[0], 0)
+
             act_op_out = act_op.inputs[0].clone("_acted")
             act_op_out.quantization = op.outputs[0].quantization.clone()
             act_op.set_output_tensor(act_op_out)
+
+            # Update the consumer list
+            act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
+            act_op_out.consumer_list.append(prep_op)
+
             prep_op.inputs[0] = act_op_out
             prep_op.outputs[0].quantization = act_op_out.quantization.clone()
 
@@ -956,6 +1035,7 @@
         convert_conv_to_fc,
         convert_softmax,
         fixup_fully_connected_input,
+        convert_batched_fc_to_conv,
         fixup_pack_input,
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,