MLBEDSW-3791 Fix converting axis to 4D axis
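Fix conversion of the concat/split axis to a 4D axis. A non-negative axis is shifted by (4 - rank of the tensor shape) so that it refers to the same dimension of the 4D shape. A negative axis already counts from the last dimension, so it is now used unchanged instead of also being shifted, which previously produced the wrong axis. This applies to both the ConcatSliceWrite concat_axis and the split offset computation.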
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I83501494738f402b374efd8a369e5001f17b8152
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 0754f7e..c321678 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -73,10 +73,14 @@
tens.ops = []
offset = 0
for idx, inp in enumerate(inputs):
+ if axis >= 0:
+ axis_4D = axis + (4 - len(inp.shape))
+ else:
+ axis_4D = axis
new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
new_op.inputs = [inp]
new_op.outputs = [tens]
- new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape))
+ new_op.attrs["concat_axis"] = axis_4D
new_op.attrs["concat_start"] = offset
offset += inp.shape[axis]
new_op.attrs["concat_end"] = offset
@@ -122,7 +126,10 @@
for idx, out in enumerate(outputs):
if out == tens:
break
- axis_4D = axis + (4 - len(out.shape))
+ if axis >= 0:
+ axis_4D = axis + (4 - len(out.shape))
+ else:
+ axis_4D = axis
offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)