MLBEDSW-2584: Support cascading of Transpose Convolution

Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com>
Change-Id: I39cff126dda89d71426ab731427ca1d64d02590d
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index dbf2b7b..cb0cc64 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -132,26 +132,27 @@
     return padding, skirt
 
 
-def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dims):
-    upscaled_shape = [input_dims[0], input_dims[1] * stride[1], input_dims[2] * stride[2], input_dims[3]]
-    ypad = needed_total_padding(int(upscaled_shape[1]), int(stride[1]), int(kernel_size[0]))
-    xpad = needed_total_padding(int(upscaled_shape[2]), int(stride[2]), int(kernel_size[1]))
-
+def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dims, upscaling_factor):
+    kernel_height, kernel_width = kernel_size[0], kernel_size[1]  # kernel_size is (height, width)
     if padding_type == b"SAME":
-        right_pad = ((xpad + 1) // 2) - 1
-        bottom_pad = ((ypad + 1) // 2) - 1
-        left_pad = max(kernel_size[0] - 1 - right_pad, 0)
-        top_pad = max(kernel_size[1] - 1 - bottom_pad, 0)
+        ypad = needed_total_padding(int(input_dims[1]) * upscaling_factor, int(stride[1]), int(kernel_height))
+        xpad = needed_total_padding(int(input_dims[2]) * upscaling_factor, int(stride[2]), int(kernel_width))
+        # ypad/xpad above are the total padding for the upscaled input height/width
+        right_pad = ((xpad + 1) // upscaling_factor) - 1
+        bottom_pad = ((ypad + 1) // upscaling_factor) - 1
+        left_pad = max(kernel_width - 1 - right_pad, 0)
+        top_pad = max(kernel_height - 1 - bottom_pad, 0)
+
     elif padding_type == b"VALID":
-        right_pad = (xpad + 1) // 2
-        bottom_pad = (ypad + 1) // 2
-        left_pad = max(kernel_size[0] - right_pad, 0)
-        top_pad = max(kernel_size[1] - bottom_pad, 0)
+        right_pad = max(kernel_width - 2, 0)  # VALID: padding depends only on the kernel size
+        bottom_pad = max(kernel_height - 2, 0)
+        left_pad = kernel_width - 1
+        top_pad = kernel_height - 1
     else:
         assert 0, "Unknown padding"
 
     padding = (top_pad, left_pad, bottom_pad, right_pad)
-    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
+    skirt = padding
     return padding, skirt


@@ -332,8 +333,9 @@
             raise UnsupportedFeatureError("Unknown operation that uses padding: {}".format(op.type))
 
         if op.type == "Conv2DBackpropInputSwitchedBias":
+            upscaling_factor = op.outputs[0].shape[1] // input_shape[1]  # OFM/IFM ratio at index 1 (assumed H)
             padding, skirt = calc_upscaled_padding_and_skirt(
-                op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape
+                op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
             )
         else:
             dilation_h, dilation_w = op.get_dilation_h_w()
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index be8aac8..0053e79 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -48,11 +48,6 @@
         new_start_coord = list(self.start_coord)
         new_end_coord = list(self.end_coord)
 
-        # Adjust for upscaling
-        if len(new_start_coord) == len(new_end_coord) == 4:
-            new_start_coord[1] = new_start_coord[1] // upscaling_factor
-            new_end_coord[1] = new_end_coord[1] // upscaling_factor
-
         new_start_coord[concat_axis] -= concat_offset
         new_end_coord[concat_axis] -= concat_offset
 
@@ -69,9 +64,10 @@
         if npu_block_type == NpuBlockType.ElementWise and min(len(new_end_coord), len(ifm_shape)) >= 1:
             new_end_coord[-1] = min(new_end_coord[-1], ifm_shape[-1])
         if min(len(new_end_coord), len(ifm_shape)) >= 2:
-            new_end_coord[-2] = min(new_end_coord[-2], ifm_shape[-2])
+            new_end_coord[-2] = min(new_end_coord[-2], ifm_shape[-2] * upscaling_factor)
+        original_end_coord = list(new_end_coord)  # end coords before the height clamp; read by pad_bottom logic
         if min(len(new_end_coord), len(ifm_shape)) >= 3:
-            new_end_coord[-3] = min(new_end_coord[-3], ifm_shape[-3])
+            new_end_coord[-3] = min(new_end_coord[-3], ifm_shape[-3] * upscaling_factor)
 
         pad_top = 0
         pad_bottom = 0
@@ -83,22 +79,31 @@
 
             if len(new_start_coord) >= 3:
                 stride = strides[1]
+                skirt_top_remainder = skirt[0] % upscaling_factor  # moved from the box start into pad_top
 
                 total_stride = stride * (new_end_coord[-3] - new_start_coord[-3] - 1)
-                new_start_coord[-3] = new_start_coord[-3] * stride - skirt[0]
+                new_start_coord[-3] = new_start_coord[-3] * stride - skirt[0] + skirt_top_remainder
 
-                pad_top = max(0, 0 - new_start_coord[-3])
+                pad_top = max(0, 0 - new_start_coord[-3]) + skirt_top_remainder
                 new_start_coord[-3] = max(new_start_coord[-3], 0)
 
                 while len(ifm_shape) < 3:
                     ifm_shape = [1] + ifm_shape
-                if (new_end_coord[-3] * stride + skirt[2]) > ifm_shape[-3]:
+                # Coordinates are still in upscaled rows here, so compare against the upscaled IFM height
+                if (new_end_coord[-3] * stride + skirt[2]) > (ifm_shape[-3] * upscaling_factor):
                     # pad_bottom is calculated based the diff between the end position of the weight kernel,
                     # after last stride and the ifm height.
-                    k_start = new_start_coord[-3] - pad_top
-                    pad_bottom = max(0, k_start + total_stride + k_height - ifm_shape[-3])
+                    if upscaling_factor != 1 and original_end_coord[-3] > ifm_shape[-3] * upscaling_factor:
+                        # Special case for Transpose Convolution with VALID padding.
+                        pad_bottom = original_end_coord[-3] - (ifm_shape[-3] * upscaling_factor)
+                    else:
+                        k_start = new_start_coord[-3] - pad_top
+                        pad_bottom = max(0, k_start + total_stride + k_height - (ifm_shape[-3] * upscaling_factor))
 
-                new_end_coord[-3] = min(new_end_coord[-3] * stride + skirt[2], ifm_shape[-3])
+                # Adjust for upscaling: map upscaled row coordinates back to IFM rows
+                new_start_coord[-3] = max(new_start_coord[-3] // upscaling_factor, 0)
+                new_end_coord[-3] = new_end_coord[-3] * stride + skirt[2] + (skirt[2] % upscaling_factor)
+                new_end_coord[-3] = min(new_end_coord[-3] // upscaling_factor, ifm_shape[-3])
 
         return Box(new_start_coord, new_end_coord), pad_top, pad_bottom
 
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 232a56c..6aa88d8 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -79,7 +79,6 @@
         skirt = ps.primary_op.attrs.get("skirt", None)
         if ps.primary_op.type in set(("Conv2DBackpropInputSwitchedBias", "ResizeBilinear")):
             upscaling = ofm_tensor.shape[-3] // ifm_tensor.shape[-3]
-            assert ofm_tensor.shape[-2] == (ifm_tensor.shape[-2] * upscaling)
 
     concat_axis = 0
     concat_offset = 0