MLBEDSW-5384 FC layers run on NPU if underlying shape is 2D

*Added a generic function which checks if the underlying shape of a
FullyConnected operation is 2D and performs shape reduction
*FullyConnected operations with more than 2 dimensions now run on the
NPU if the above case is satisfied
*constraint_fc_output_2d and rewrite_fully_connected_input refactored
*Added unit test to confirm this functionality

Signed-off-by: Ayaan Masood <Ayaan.Masood@arm.com>
Change-Id: I0e29c767e5b84841eb53bbc44464b36a454f7b38
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 38b0e43..e981584 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -823,6 +823,19 @@
         else:
             return self.values.item(0)
 
+    def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]:
+
+        elms = self.elements()
+        dimension_1_size = elms // dimension_2_size
+        # Checks if the reduction works and shape is not 1D
+        is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1)
+
+        new_shape = None
+        if is_reducible:
+            new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size])
+
+        return new_shape
+
     def __lt__(self, other: "Tensor") -> bool:
         return self.equivalence_id < other.equivalence_id
 
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 1e5dbd4..2d6ca15 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -81,11 +81,13 @@
 
 
 def test_constraint_fc_output_2d_not_supp():
-    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [7, 4, 6], [3, 2, 2, 8], weights_shape=[1, 9, 1])
     assert not semantic_checker.is_operator_semantic_valid(op)
-    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 6, 1], [3, 7, 4], weights_shape=[1, 1, 7, 1])
     assert not semantic_checker.is_operator_semantic_valid(op)
-    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 1, 4, 7], [1, 9], weights_shape=[12, 3])
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4], [9], weights_shape=[3, 2])
     assert not semantic_checker.is_operator_semantic_valid(op)
 
 
@@ -94,6 +96,20 @@
     assert semantic_checker.is_operator_semantic_valid(op)
     op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
     assert semantic_checker.is_operator_semantic_valid(op)
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1, 1], weights_shape=[12, 1, 1, 1])
+    assert semantic_checker.is_operator_semantic_valid(op)
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1], weights_shape=[12, 1, 1, 1])
+    assert semantic_checker.is_operator_semantic_valid(op)
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [1, 1, 3, 2], weights_shape=[12, 1, 1, 1])
+    assert semantic_checker.is_operator_semantic_valid(op)
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 1, 1], weights_shape=[12, 1, 1, 1])
+    assert semantic_checker.is_operator_semantic_valid(op)
+    op = testutil.create_op_with_quant_tensors(
+        Op.FullyConnected, [12, 1, 1, 1], [1, 1, 24], weights_shape=[12, 1, 1, 1]
+    )
+    assert semantic_checker.is_operator_semantic_valid(op)
+    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1, 3, 4], weights_shape=[1, 1, 1, 1])
+    assert semantic_checker.is_operator_semantic_valid(op)
 
 
 def test_constraint_conv_pass():
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index b2a3419..0639578 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -379,14 +379,12 @@
     return op
 
 
-def rewrite_fully_connected_input(op, arch, nng):
-    if op.type == Op.FullyConnected:
-        n_in_elems = op.weights.shape[-2]
-        elms = op.ifm.elements()
-        batch_size = elms // n_in_elems
-        assert batch_size * n_in_elems == elms
+def rewrite_fully_connected_input(op: Operation, arch, nng):
 
-        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
+    if op.type == Op.FullyConnected:
+        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
+        assert new_shape is not None, "Tensor can not be reshaped to 2D"
+        op.ifm_shapes[0] = new_shape
     return op
 
 
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index b264479..c811a0d 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -295,14 +295,11 @@
 
     @staticmethod
     def constraint_fc_output_2d(op):
-        "The output tensor(s) must have 2D shape"
-        valid = True
-        extra = []
-        for tens in op.outputs:
-            if len(tens.shape) != 2:
-                valid = False
-                extra.append(f"Tensor '{tens.name}' is {len(tens.shape)}D")
-        return valid, ", ".join(extra)
+        """The output tensor(s) must have 2D shape"""
+        valid = op.ifm.get_shape_as_2d(op.weights.shape[-2]) is not None
+        extra = f"Op has non-2D output tensor '{op.ofm.name}'" if not valid else ""
+
+        return valid, extra
 
     @staticmethod
     def constraint_stride_type(op):