MLBEDSW-1493: Optimise strided conv

  - Reshape/rearrange the IFM and weight tensors of stride-2 convolutions
    for better HW utilization (see the sketch below)
  - Update the performance estimator to cover this case
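
  A rough sketch (not part of the patch) of the equivalence this rewrite
  relies on, assuming an NHWC IFM, HWIO weights, VALID padding and
  stride_y == 1; the conv2d helper below is hypothetical:

      import numpy as np

      def conv2d(x, w, stride_x):
          # naive NHWC * HWIO convolution, VALID padding, stride 1 in y
          n, h, wi, c = x.shape
          kh, kw, _, co = w.shape
          oh = h - kh + 1
          ow = (wi - kw) // stride_x + 1
          out = np.zeros((n, oh, ow, co))
          for y in range(oh):
              for xo in range(ow):
                  patch = x[:, y : y + kh, xo * stride_x : xo * stride_x + kw, :]
                  out[:, y, xo, :] = np.tensordot(patch, w, axes=([1, 2, 3], [0, 1, 2]))
          return out

      rng = np.random.default_rng(0)
      ifm = rng.standard_normal((1, 8, 8, 3))      # width even, depth <= 4
      weights = rng.standard_normal((3, 2, 3, 8))  # kernel width even

      ref = conv2d(ifm, weights, stride_x=2)

      # Fold pairs of IFM columns and kernel columns into the depth dimension,
      # after which the convolution can run with stride 1 in x
      ifm_folded = ifm.reshape(1, 8, 4, 6)           # [N, H, W/2, 2*C]
      weights_folded = weights.reshape(3, 1, 6, 8)   # [KH, KW/2, 2*C, O]
      opt = conv2d(ifm_folded, weights_folded, stride_x=1)

      assert np.allclose(ref, opt)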

Change-Id: I4be70a69fa600a1951bf1c247f9973e6cc9b03f4
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index c321678..3759d3b 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -17,6 +17,7 @@
 # Early optimisation of the network graph, using the rewrite_graph module to do the traversal of the graph. These are
 # split into two parts optimise_graph_a and optimise_graph_b.
 import math
+import uuid
 
 import numpy as np
 
@@ -598,6 +599,56 @@
     return op
 
 
+def optimise_strided_conv(op, arch, nng):
+    stride_x, stride_y = op.get_kernel_stride()
+    ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+
+    if (
+        op.type == Op.Conv2DBias
+        and op.op_index == 0
+        and stride_x == 2
+        and len(ifm_tensor.shape) == 4
+        and ifm_tensor.shape[3] <= 4
+        and ifm_tensor.shape[2] % 2 == 0
+        and weight_tensor is not None
+        and weight_tensor.shape[1] >= 2
+    ):
+        # IFM
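+        # Fold pairs of IFM columns into the depth dimension:
+        # [N, H, W, C] -> [N, H, W/2, 2*C]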
+        ifm_reshaped = create_reshape_tensor(
+            ifm_tensor, [ifm_tensor.shape[0], ifm_tensor.shape[1], ifm_tensor.shape[2] // 2, ifm_tensor.shape[3] * 2]
+        )
+        op.set_input_tensor(ifm_reshaped, 0)
+
+        # Weights
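+        # Pad the kernel width to an even size with zero-point values, then
+        # fold pairs of kernel columns into the depth dimension
+        # (weight layout assumed to be [KH, KW, C, O])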
+        weight_shape = weight_tensor.shape
+        if weight_shape[1] % 2 != 0:
+            weight_shape[1] = weight_shape[1] + 1
+            padded_array = np.zeros(weight_shape)
+            for i in range(weight_shape[0]):
+                padded_array[i] = np.vstack(
+                    [
+                        weight_tensor.quant_values[i],
+                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
+                    ]
+                )
+            weight_tensor.quant_values = padded_array
+        weight_shape[1] //= 2
+        weight_shape[2] *= 2
+        weight_tensor.quant_values = np.reshape(weight_tensor.quant_values, weight_shape)
+        weight_tensor.set_all_shapes(weight_shape)
+        # If multiple copies of the weights are used, we could avoid
+        # them having the same address by changing the value_id
+        weight_tensor.value_id = uuid.uuid4()
+
+        # Strides
+        stride_x = 1
+        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
+
+        op.set_ifm_ofm_shapes()
+
+    return op
+
+
 def convert_conv_to_fc(op, arch, nng):
     # Conv 1x1 can be equivalent to Fully Connected.
     # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
@@ -1134,6 +1185,7 @@
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,
+        optimise_strided_conv,
         fixup_fully_connected_input,
         convert_batched_fc_shape,
         fixup_pack_input,
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 4ca4683..c2418d7 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -422,6 +422,9 @@
         ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
         ifm_tensor_shape = ps.primary_op.ifm_shapes[0].clone()
         ofm_tensor_shape = ps.primary_op.ofm_shapes[0].clone()
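+        # The block config must not exceed the actual OFM dimensions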
+        ofm_block.width = min(ofm_block.width, ofm_tensor_shape.width)
+        ofm_block.height = min(ofm_block.height, ofm_tensor_shape.height)
+        ofm_block.depth = min(ofm_block.depth, ofm_tensor_shape.depth)
 
         if npu_block_type == NpuBlockType.ReduceSum:
             block_traversal = TensorBlockTraversal.DepthFirst
@@ -439,6 +442,8 @@
         ifm_block = arch.get_ifm_block_size(
             ifm_block_depth, ofm_block, primary_op.kernel, ifm_resampling_mode=ifm_tensor.resampling_mode
         )
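+        # Likewise, the IFM block must not exceed the actual IFM dimensions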
+        ifm_block.width = min(ifm_block.width, ifm_tensor_shape.width)
+        ifm_block.height = min(ifm_block.height, ifm_tensor_shape.height)
 
         if npu_block_type in (
             NpuBlockType.ConvolutionMxN,
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index d8faf36..2043864 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -172,7 +172,11 @@
 
 
 def is_acc_40bits_used(npu_block_type, ifm_tensor, ofm_tensor, ifm2_tensor=None):
-    return npu_block_type != NpuBlockType.Pooling and _all_fms_have_quant(ifm_tensor, ofm_tensor, ifm2_tensor)
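+    # 40-bit accumulation is only used when the IFM is 16-bit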
+    return (
+        ifm_tensor.dtype.size_in_bits() == 16
+        and npu_block_type != NpuBlockType.Pooling
+        and _all_fms_have_quant(ifm_tensor, ofm_tensor, ifm2_tensor)
+    )
 
 
 def shared_buffer_allocation_for_pass(arch, ps) -> SharedBufferAllocation: