MLBEDSW-3772 Fix FC with changed input shape

When the FullyConnected input is fixed up by changing ifm_shape,
avoid_NHCWB16 must be set on the ifm tensor.

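- Fixed ResizeBilinear: read the HW shapes from ifm_shapes/ofm_shapes
  (as numpy arrays) instead of ifm_shape/ofm_shape
- Changed concat op rewriting in graph optimisation to a post-order traversal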

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ie0c6a86637c210c0833ae9b2f8e7c494c5d4f66e
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index f1b2d35..ab4d916 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -59,7 +59,7 @@
     return tens
 
 
-def rewrite_concat_ops(op, arch, nng):
+def rewrite_concat_ops(op, arch):
     if not op.run_on_npu or not op.type.is_concat_op():
         return op
 
@@ -283,8 +283,8 @@
         op.attrs["padding"] = Padding.SAME
     op.inputs[0].resampling_mode = resampling_mode.NEAREST
 
-    upscaled_shape = op.ifm_shape[0].get_hw_as_list()
-    out_shape = op.ofm_shape[0].get_hw_as_list()
+    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
+    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
     if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
         return op
 
@@ -346,6 +346,33 @@
     return op
 
 
+def rewrite_fully_connected_input(op, arch, nng):
+    """Give a FullyConnected op a 4D ifm shape of [batch_size, 1, 1, n_in_elems].
+
+    n_in_elems is the second-to-last dimension of the weights, and batch_size is
+    the total number of ifm elements divided by n_in_elems (the division must be
+    exact). op.ifm_shapes[0] is always overwritten with the resulting shape.
+    If the ifm tensor's own shape is not [batch_size, n_in_elems], the tensor is
+    flagged with avoid_NHCWB16.
+
+    Other op types are returned untouched. arch and nng are unused; the signature
+    matches the other op rewrite functions such as convert_batched_fc_shape.
+    """
+    if op.type == Op.FullyConnected:
+        n_in_elems = op.weights.shape[-2]
+        elms = op.ifm.elements()
+        batch_size = elms // n_in_elems
+        assert batch_size * n_in_elems == elms
+
+        if op.ifm.shape != [batch_size, n_in_elems]:
+            # The ifm data is reinterpreted under a different shape, so keep it
+            # out of NHCWB16 (presumably only a linear layout allows this).
+            op.ifm.avoid_NHCWB16 = True
+
+        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
+    return op
+
+
 def convert_batched_fc_shape(op, arch, nng):
     if op.type == Op.FullyConnected:
         # Check if the first dimension indicates batching
@@ -1199,9 +1226,8 @@
     # Handle Concat Ops
-    for idx, sg in enumerate(nng.subgraphs):
-        # rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [rewrite_concat_ops], rewrite_unsupported=False,
-        )
+    for sg in nng.subgraphs:
+        # Rewrite concat ops in place, visiting the graph in post order
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
+        sg.refresh_after_modification()
 
     # Handle Split Ops
     for idx, sg in enumerate(nng.subgraphs):
@@ -1232,6 +1258,7 @@
         convert_conv_to_fc,
         convert_softmax,
         optimise_strided_conv,
+        rewrite_fully_connected_input,
         convert_batched_fc_shape,
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 342efd9..8d54d65 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -719,14 +719,16 @@
 
         # set all shapes to op, as 4D
         if self.type == Op.FullyConnected:
-            n_in_elems = weight_tensor.shape[-2]
-            elms = ifm_tensor.elements()
-            batch_size = elms // n_in_elems
-            assert batch_size * n_in_elems == elms
-
-            self.ifm_shapes.append(Shape4D([batch_size, 1, 1, n_in_elems]))
-            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
+            if len(self.ifm.shape) == 2:
+                self.ifm_shapes.append(Shape4D([self.ifm.shape[0], 1, 1, self.ifm.shape[1]]))
+            else:
+                # Special case, final 4D shape is set by rewrite_fully_connected_input in graph_optimiser
+                self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
+            if len(self.ofm.shape) == 2:
+                self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]]))
+            else:
+                self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
         elif self.type == Op.Softmax:
             self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
             self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
         elif self.type.is_split_op or self.type.is_concat_op():
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index b01b07c..55980e3 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -22,6 +22,7 @@
 from ethosu.vela.graph_optimiser import convert_batched_fc_shape
 from ethosu.vela.graph_optimiser import optimise_graph_a
 from ethosu.vela.graph_optimiser import optimise_pad
+from ethosu.vela.graph_optimiser import rewrite_fully_connected_input
 from ethosu.vela.nn_graph import Graph
 from ethosu.vela.operation import Op
 from ethosu.vela.operation import Padding
@@ -47,8 +48,9 @@
     prev_op.ifm_shapes = op.ifm_shapes.copy()
     prev_op.ofm_shapes = op.ofm_shapes.copy()
 
+    rewrite_fully_connected_input(op, None, None)
     conv_op = convert_batched_fc_shape(op, None, None)
 
     assert conv_op.ifm == prev_op.ifm
     assert conv_op.ofm == prev_op.ofm
     assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
@@ -68,6 +70,7 @@
     prev_op.ifm_shapes = op.ifm_shapes.copy()
     prev_op.ofm_shapes = op.ofm_shapes.copy()
 
+    rewrite_fully_connected_input(op, None, None)
     conv_op = convert_batched_fc_shape(op, None, None)
 
     assert conv_op.ifm == prev_op.ifm