MLBEDSW-5554: Constraints for single-axis mean operations on NPU

 - Combine two MEAN operator checks for single axis averages into one
 - Only apply that check if the single axis is the height dimension
   (previously checks were also applied to width averages)
 - Rephrase some MEAN operator constraint descriptions

Signed-off-by: James Peet <james.peet@arm.com>
Change-Id: Ie0577f2b99aba1f3d6a4c39f8934eafe3813b736
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 193a23f..4d82677 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -209,8 +209,7 @@
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool)
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8)
-        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_depthwise_conv_height_single_axis)
-        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_avgpool_height_single_axis)
+        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_single_axis)
 
         # Reshape specific checks:
         self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
@@ -637,7 +636,7 @@
     @classmethod
     @docstring_format_args([mean_kernel_product_avgpool])
     def constraint_mean_height_width_product_avgpool(cls, op):
-        """Product of height and width can be at most {}"""
+        """Product of height and width must be no greater than {}"""
         shape = op.inputs[0].shape
         hi = 0 if len(shape) < 4 else 1
         h, w = shape[hi : hi + 2]
@@ -647,8 +646,9 @@
     @classmethod
     @docstring_format_args([mean_kernel_product])
     def constraint_mean_height_width_product(cls, op):
-        """Product of height and width can be at most {} when IFM and OFM have different scale or zero point,
-        or keep_dims is True"""
+        """Product of height and width must be no greater than {} when:
+        IFM and OFM have different scale or zero point; or
+        'keep_dims' is True"""
         ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
         keep_dims = op.attrs.get("keep_dims")
         # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
@@ -663,10 +663,11 @@
     @classmethod
     @docstring_format_args([mean_kernel_product_int8])
     def constraint_mean_height_width_product_int8(cls, op):
-        """Product of IFM height and width can be at most {} when the following are true:
-        IFM dimensions are 4,
-        Axis indices are 1 and 2,
-        keep_dims is set to True and
+        """Product of IFM height and width must be no greater than {} when:
+        The IFM shape has 4 dimensions; and
+        The axis indices specify reduction across 2 dimensions; and
+        The axis indices correspond to the width and height dimensions of the IFM; and
+        'keep_dims' is True; and
         IFM datatype is int8"""
         shape = op.ifm.shape
         axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
@@ -679,51 +680,39 @@
             or axis not in ([1, 2], [2, 1])
         ):
             return True, ""
-        hi = 0 if len(shape) < 4 else 1
-        h, w = shape[hi : hi + 2]
+        h = shape[-3]
+        w = shape[-2]
         max_prod = cls.mean_kernel_product_int8
         return h * w <= max_prod, f"Product of height and width is {h * w}"
 
     @classmethod
-    @docstring_format_args([dilated_height_range[1]])
-    def constraint_depthwise_conv_height_single_axis(cls, op):
-        """Height can be at most {} for single axis when axis is 1."""
+    @docstring_format_args([filter_height_range[1], dilated_height_range[1]])
+    def constraint_mean_height_single_axis(cls, op):
+        """For single axis averages across the height dimension:
+        IFM height must be no greater than {} if the IFM and OFM scale and zero point match; otherwise
+        IFM height must be no greater than {} if the IFM and OFM scale or zero point do not match"""
         inp, axis = op.inputs
         if axis.shape == [] or axis.shape[0] == 1:  # single axis
             axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
         else:
-            # Multiple axes, constraint does not apply
+            # Multiple axes
             return True, ""
 
-        # Height and width axes have different index depending on dimensions
         shape = inp.shape
-        h = shape[0] if len(shape) < 4 else shape[1]
+        if len(shape) < 3:
+            # No height dimension present in IFM
+            return True, ""
+        if axis != len(shape) - 3:
+            # Not averaging across the height dimension
+            return True, ""
 
-        # If quantization is the same across IFM and OFM op will become avgpool and this constraint does not apply.
+        h = shape[axis]
         ifm, ofm = op.get_ifm_ofm()
+
         if check_quantized_tens_scaling_equal(ifm, ofm):
-            return True, ""
-
-        return h <= 64 or axis != 1, f"Height is {h} and axis is {axis}."
-
-    @classmethod
-    @docstring_format_args([filter_height_range[1]])
-    def constraint_avgpool_height_single_axis(cls, op):
-        """Avgpool height can be at most {} for single axis when axis is 1."""
-        inp, axis = op.inputs
-        if axis.shape == [] or axis.shape[0] == 1:  # single axis
-            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
+            return h <= cls.filter_height_range[1], f"Height is {h}, IFM and OFM quantizations match"
         else:
-            # Multiple axes, constraint does not apply
-            return True, ""
-
-        # Height and width axes have different index depending on dimensions
-        shape = inp.shape
-        h = shape[0] if len(shape) < 4 else shape[1]
-        ifm, ofm = op.get_ifm_ofm()
-        scaling_equal = check_quantized_tens_scaling_equal(ifm, ofm)
-
-        return h <= 256 or axis != 1 or not scaling_equal, f"Height is {h} and axis is {axis}"
+            return h <= cls.dilated_height_range[1], f"Height is {h}, IFM and OFM quantizations do not match"
 
     @staticmethod
     def constraint_reshape_shape_constant(op):