MLBEDSW-2814 Add support for inferred size in SplitV

For SplitV, size_splits may contain a single -1, indicating that the
size of that split along the split axis is to be inferred.

Add support for inferring that size from the input shape. A SplitV
with more than one -1 in size_splits is placed on the CPU.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ib9fc8dd2ee1749e81a978d85f2d4a016698bb441
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 8dec379..4b83b39 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -259,9 +259,17 @@
             size_tens = self.inputs[1]
             assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
             sizes = size_tens.values
+
             axis_tens = self.inputs[2]
             assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
             axis = int(axis_tens.values)
+
+            for idx, size in enumerate(sizes):
+                # One but only one size might be set to -1, indicating that size should be inferred
+                if size == -1:
+                    sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
+                    break
+
             outputs = self.outputs
             assert num_splits == len(outputs)
             assert sum(sizes) == input_tens.shape[axis]
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 7cff0ee..e0ee616 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -378,6 +378,18 @@
             # check if both new_axis_mask and shrink_axis_mask have bit set
             if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
                 return False
+        if op.type == "SplitV":
+            # check that maximum one size is set to -1, indicating that size should be inferred
+            sizes = op.inputs[1].values
+            num_to_be_inferred = 0
+            for size in sizes:
+                if size == -1:
+                    num_to_be_inferred += 1
+
+            if num_to_be_inferred > 1:
+                print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
+                return False
+
         return True
 
     def check_quantization_restrictions_binary_elem_wise(self, op):