MLBEDSW-3791 Fix converting axis to 4D axis

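Fix converting a negative axis to a 4D axis in the concat and split
rewrites. A negative axis counts from the last dimension, so it already
refers to the same dimension once the shape is extended to 4D. Adding
the (4 - rank) offset moved it to the wrong dimension (e.g. axis -1 of a
3D tensor became axis 0). Only non-negative axes are now offset.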

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I83501494738f402b374efd8a369e5001f17b8152
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 0754f7e..c321678 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -73,10 +73,14 @@
         tens.ops = []
         offset = 0
         for idx, inp in enumerate(inputs):
+            if axis >= 0:
+                axis_4D = axis + (4 - len(inp.shape))
+            else:
+                axis_4D = axis
             new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
             new_op.inputs = [inp]
             new_op.outputs = [tens]
-            new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape))
+            new_op.attrs["concat_axis"] = axis_4D
             new_op.attrs["concat_start"] = offset
             offset += inp.shape[axis]
             new_op.attrs["concat_end"] = offset
@@ -122,7 +126,10 @@
             for idx, out in enumerate(outputs):
                 if out == tens:
                     break
-                axis_4D = axis + (4 - len(out.shape))
+                if axis >= 0:
+                    axis_4D = axis + (4 - len(out.shape))
+                else:
+                    axis_4D = axis
 
                 offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)