MLBEDSW-7315: Add support for AvgPool with stride_width > 3

* Convert AvgPool with stride_width > 3 and Valid padding to Conv2D to
  optimize it to run on NPU.

Change-Id: I06ab412357f0b09b1498f9019a9d1963a324ad34
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 947b585..fdceb43 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -134,7 +134,8 @@
 - Stride values for both width and height must be integer types
 - IFM and OFM data types must match
 - Kernel filter values for both width and height must be integer types
-- Stride values for both width and height must be in the range [1, 3]
+- Stride width must be greater than or equal to 1.  
+        For stride width greater than 3, valid padding needs to be used.
 - Kernel filter values for both width and height must be in the range [1, 8]
 - VALID padding: Kernel filter height must be in the range [1, 256]
 - VALID padding: Product of kernel filter width and height must be in the range [1, 65536]
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 9526bd5..79ac392 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -308,6 +308,14 @@
     if tens.dtype == DataType.int32 and is_ifm_tensor:
         return True
     if ps.primary_op.rounding_mode == RoundingMode.AwayZero:
+        if (
+            ps.primary_op.original_type == Op.AvgPool
+            and ps.primary_op.type == Op.Conv2DBias
+            and ps.primary_op.attrs.get("padding", None) == Padding.VALID
+        ):
+            # Force zero point to 0 for AveragePool operators converted to a Conv2DBias with rounding away from
+            # zero.
+            return True
         if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
             # Force zero point to 0 for ResizeBilinear operators converted to a DepthwiseConv with rounding away from
             # zero. This is because the reference kernel ignores the zero points.
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 3685c5a..998d94f 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -615,7 +615,7 @@
             is_supported = False
             if self.original_type == Op.ResizeBilinear and self.type == Op.DepthwiseConv2DBias:
                 is_supported = True
-            if self.original_type == Op.AvgPool and self.type == Op.DepthwiseConv2DBias:
+            if self.original_type == Op.AvgPool and self.type in (Op.DepthwiseConv2DBias, Op.Conv2DBias):
                 is_supported = True
 
         if is_supported:
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 76383a4..7890637 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -949,7 +949,57 @@
     return op
 
 
-def fixup_strided_conv(op: Operation, arch, nng) -> Operation:
+def convert_avg_pool_to_conv2d(op: Operation, arch, nng) -> Operation:
+    """Convert strided Average Pools with stride >= 4 to Conv2D."""
+    if op.type != Op.AvgPool:
+        return op
+
+    stride_x, stride_y = op.get_kernel_stride()
+    # For strides <= 3 no optimization is needed
+    if stride_x <= 3:
+        return op
+    h, w = op.attrs["filter_height"], op.attrs["filter_width"]
+    inputs = op.inputs[0]
+    shape = inputs.shape
+
+    # Set necessary conv2d attributes
+    op.attrs.update(
+        {
+            "stride_h": stride_y,
+            "stride_w": stride_x,
+            "dilation_h_factor": 1,
+            "dilation_w_factor": 1,
+            "strides": (1, stride_y, stride_x, 1),
+            "dilation": (1, 1, 1, 1),
+        }
+    )
+
+    # Change op type
+    op.type = Op.Conv2DBias
+    op.name += "_conv2d"
+
+    op.rounding_mode = RoundingMode.AwayZero
+    shape = [h, w, 1, op.ofm.shape[-1]]
+    weights = np.full(shape, 1)
+    quant = QuantizationParameters(scale_f32=1 / (h * w), zero_point=0)
+    # Add unit weight tensor
+    op.add_input_tensor(
+        create_const_tensor(
+            "weights",
+            shape,
+            inputs.dtype,
+            weights,
+            quantization=quant,
+        ),
+    )
+    op.weights.values = np.reshape(op.inputs[1].values, shape)
+
+    # Set IFM/OFM shapes after changing op type
+    op.set_ifm_ofm_shapes()
+    return op
+
+
+def fixup_strided_conv(op: Operation, arch, nng):
     """Optimize or fixup strided Conv2DBias
     Optimization:
         Reduce, when possible, the Conv2DBias stride from N with 1 > N > 4 to 1
@@ -1853,7 +1903,11 @@
         if dtype is None:
             dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
         bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
-        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
+        bias_index = op.type.info.indices.biases[0]
+        if bias_index < len(op.inputs):
+            op.set_input_tensor(bias_tensor, bias_index)
+        else:
+            op.add_input_tensor(bias_tensor)
 
     return op
 
@@ -2349,6 +2403,7 @@
         convert_prelu,
         convert_mul_max_to_abs_or_lrelu,
         convert_lrelu,
+        convert_avg_pool_to_conv2d,
         fixup_strided_conv,
         convert_hardswish_to_lut,
         rewrite_fully_connected_input,
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 25b6897..a24eebc 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -220,7 +220,7 @@
 
         # Conv specific ops:
         for op_type in TFLiteSupportedOperators.convolution_ops:
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_conv_stride)
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_width_no_upper_limit)
 
         # Conv-like checks:
         for op_type in TFLiteSupportedOperators.convolution_like_ops:
@@ -244,10 +244,11 @@
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier)
 
         # Pooling checks:
-        for op_type in TFLiteSupportedOperators.pooling_ops:
+        for op_type in TFLiteSupportedOperators.pooling_ops - TFLiteSupportedOperators.avg_pooling_ops:
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
         # AVG pooling specific checks:
         for op_type in TFLiteSupportedOperators.avg_pooling_ops:
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range_no_padding)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_range)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_filter_height_range_valid_pad)
             self.specific_constraints[op_type].append(
@@ -545,7 +546,7 @@
         return True, "Op has depth_multiplier=1"
 
     @staticmethod
-    def constraint_conv_stride(op):
+    def constraint_stride_width_no_upper_limit(op):
         """Stride width must be greater than or equal to 1.
         For stride widths greater than 3, the post-optimization stride needs to be less than or equal to 3.
         Stride height must be between 1 and 3."""
@@ -561,6 +562,17 @@
         return valid, f"Op has stride WxH as: {w}x{h}"
 
     @staticmethod
+    def constraint_stride_range_no_padding(op):
+        """Stride width must be greater than or equal to 1.
+        For stride width greater than 3, valid padding needs to be used."""
+        w, _ = op.get_kernel_stride()
+        valid, message = TFLiteSupportedOperators.constraint_stride_width_no_upper_limit(op)
+        padding = op.attrs.get("padding", None)
+        is_optimized_with_valid_padding = padding in (None, Padding.VALID) or w <= 3
+        valid = valid and is_optimized_with_valid_padding
+        return valid, f"{message}, padding: {padding}"
+
+    @staticmethod
     def constraint_depthwise_conv_stride(op):
         "Stride values for both width and height must be between 1 and 3"
         w, h = op.get_kernel_stride()
@@ -614,10 +626,11 @@
     def constraint_filter_range(cls, op):
         "Kernel filter values for both width and height must be in the range [{}, {}]"
         if op.attrs["padding"] == Padding.SAME:
+            sw, _ = op.get_kernel_stride()
             w = op.kernel.width
             h = op.kernel.height
             filter_min, filter_max = cls.filter_range
-            valid = (filter_min <= w <= filter_max) and (filter_min <= h <= filter_max)
+            valid = ((filter_min <= w <= filter_max) or sw == w) and (filter_min <= h <= filter_max)
             return valid, f"Op has kernel filter WxH as: {w}x{h}"
         return True, "Op has padding=VALID"