MLBEDSW-7654: Extend support for Mean where HxW > 4096

* Convert Means with large IFMs to several DepthwiseConv2DBias and Add
  operations.
* Update tflite supported operator check with new height and width
  constraints.
* Update unit-tests to verify supported operator changes.
* Fix output-diff for 2D IFMs (MLBEDSW-7772)

Signed-off-by: Alexander Hansson <Alexander.Hansson@arm.com>
Change-Id: Ifae6fb1cdac475ae7dac5116c5f13631ff82108a
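
As a reviewer aid, the splitting arithmetic used by the new
convert_mean_to_depthwise_conv path can be sketched standalone as below.
This is a minimal illustration only: the constants and the
split_mean_reduction helper are hypothetical names, not identifiers from
the patch, but the slicing mirrors the height_per_conv/num_convs logic
and the 1x190x64x1 example in the code comment.

import math

# Illustrative constants mirroring the limits used in the patch
MAX_KERNEL_SIZE = 4096  # max h * w per DepthwiseConv2DBias
MAX_HEIGHT = 64         # max kernel height per DepthwiseConv2DBias


def split_mean_reduction(ifm_shape):
    # Sketch: per-convolution read offsets/shapes for a Mean reduced over
    # H and W, following the height_per_conv/num_convs logic in the patch.
    n, h, w, c = ifm_shape
    height_per_conv = min(MAX_KERNEL_SIZE // w, h, MAX_HEIGHT)
    num_convs = math.ceil(h / height_per_conv)
    slices = []
    for i in range(num_convs):
        offset_h = i * height_per_conv
        # The last slice may be shorter when h is not a multiple of height_per_conv
        read_h = min(height_per_conv, h - offset_h)
        slices.append(([0, offset_h, 0, 0], [n, read_h, w, c]))
    return slices


# For a 1x190x64x1 IFM this yields offsets/heights 0/64, 64/64 and 128/62,
# matching the docstring diagram added in tflite_graph_optimiser.py.
for i, (read_offset, read_shape) in enumerate(split_mean_reduction([1, 190, 64, 1]), 1):
    print(f"{i}_DepthwiseConv2DBias: read_offset {read_offset} read_shape {read_shape}")
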
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 5258946..c1c58d3 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -19,7 +19,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `3.8.1.dev7+g0f09dd2`
+Vela version: `3.8.1.dev9+g85b7790.d20230616`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -252,9 +252,11 @@
 
 - Input tensor must be at least 2D
 - Axis indices must correspond to height and width axes
-- Product of height and width must be no greater than 4096
-- For single axis averages across the height dimension:  
-        IFM height must be no greater than 64
+- Product of height and width must be no greater than:  
+        - 16777216 for signed 8-bit inputs  
+        - 8388608 for unsigned 8-bit inputs  
+        - 65536 for signed 16-bit inputs
+- Width must be no greater than 4096
 
 ### TFLite MINIMUM Constraints
 
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 6f3553d..f2ad858 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -613,9 +613,26 @@
 
 
 def test_mean_hw_product():
-    op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {})
+    # max kernel size checks
+    op = create_mean([1, 4096, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {})
     assert support.is_operator_supported(op)
-    op = create_mean([1, 65, 64, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+    op = create_mean([1, 4097, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {})
+    assert not support.is_operator_supported(op)
+
+    op = create_mean([1, 2048, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.uint8, {})
+    assert support.is_operator_supported(op)
+    op = create_mean([1, 2049, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.uint8, {})
+    assert not support.is_operator_supported(op)
+
+    op = create_mean([1, 16, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.int16, {})
+    assert support.is_operator_supported(op)
+    op = create_mean([1, 17, 4096, 16], [1, 1, 1, 16], [1, 2], DataType.int16, {})
+    assert not support.is_operator_supported(op)
+
+    # h > 4096 is OK but w > 4096 is not
+    op = create_mean([1, 4097, 10, 16], [1, 1, 1, 16], [1, 2], DataType.uint8, {"keep_dims": True})
+    assert support.is_operator_supported(op)
+    op = create_mean([1, 10, 4097, 16], [1, 1, 1, 16], [1, 2], DataType.int16, {"keep_dims": True})
     assert not support.is_operator_supported(op)
 
 
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 7890637..21c02f3 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -57,6 +57,7 @@
 from .operation import Operation
 from .operation import Padding
 from .operation import RoundingMode
+from .operation_util import create_add
 from .operation_util import create_add_nop
 from .operation_util import create_avgpool_nop
 from .operation_util import create_cast_op
@@ -942,9 +943,10 @@
 def reorder_depthwise_weights(op, arch, nng):
     if op.type.is_depthwise_conv2d_op():
         weight_tensor = op.inputs[1]
-        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
-        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
-        weight_tensor.weight_transpose_depthwise = True
+        if not weight_tensor.weight_transpose_depthwise:
+            weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
+            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
+            weight_tensor.weight_transpose_depthwise = True
 
     return op
 
@@ -1949,44 +1951,45 @@
 
 
 def convert_mean_to_depthwise_conv(op, arch, nng):
+    """
+    When h x w <= 4096     When h x w > 4096 the mean must be split into several ops.
+                           Do this by splitting up h and changing the read_offset/shape.
+                           Below is an example where the ifm is 1x190x64x1
+           MEAN                                           MEAN
+             |                      |-----------------------|----------------------|
+    DepthwiseConv2DBias    1_DepthwiseConv2DBias   2_DepthwiseConv2DBias   3_DepthwiseConv2DBias
+             |                      |                       |                     |
+            MUL                     |---------ADD-----------|                     |
+                                               |                                  |
+                                               |----------------ADD---------------|
+                                                                 |
+                                                                MUL
+               1_DepthwiseConv2DBias: read_offset [0, 0, 0, 0],   read_shape [1, 64, 64, 1]
+               2_DepthwiseConv2DBias: read_offset [0, 64, 0, 0],  read_shape [1, 64, 64, 1]
+               3_DepthwiseConv2DBias: read_offset [0, 128, 0, 0], read_shape [1, 62, 64, 1]
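+
+               For this 1x190x64x1 example the split follows the height_per_conv /
+               num_convs computation below (an illustrative walk-through):
+                   height_per_conv = min(4096 // 64, 190, 64) = 64
+                   num_convs = ceil(190 / 64) = 3
+               so the 190 rows are read as three slices of 64, 64 and 62 rows.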
+    """
     if op.type == Op.Mean and op.run_on_npu:
+        max_kernel_size = 4096
+        max_height = 64
         inp, axis = op.inputs
         shape = inp.shape
         ofm_shape = op.ofm.shape
         dims = len(shape)
         dims_ofm = len(ofm_shape)
+        ofmq = op.ofm.quantization
+        ifmq = op.ifm.quantization
 
         # Height and width axes have different index depending on dimensions
         if axis.shape == [] or axis.shape[0] == 1:  # single axis
             axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
-            if dims in (2, 3):
-                # If dims is 2 or 3, axis 0 refers to h-dimension
-                h, w = (shape[axis], 1) if axis == 0 else (1, shape[axis])
+            # If dims is 4, axis 1 refers to h-dimension
+            if dims == 4:
+                reduce_h, reduce_w = (True, False) if axis == 1 else (False, True)
             else:
-                # If dims is 4, axis 1 refers to h-dimension
-                h, w = (shape[axis], 1) if axis == 1 else (1, shape[axis])
+                reduce_h, reduce_w = (True, False) if axis == 0 else (False, True)
         else:  # multiple axes
             axis = sorted(axis.values)
-            h, w = [shape[i] for i in axis]
-
-        # Set necessary depthwise attributes
-        op.attrs.update(
-            {
-                "padding": Padding.VALID,
-                "stride_h": 1,
-                "stride_w": 1,
-                "strides": (1, 1, 1, 1),
-                "depth_multiplier": 1,
-                "channel_multiplier": 1,
-                "dilation_h_factor": 1,
-                "dilation_w_factor": 1,
-                "dilation": (1, 1, 1, 1),
-            }
-        )
-        # Change op type
-        op.type = Op.DepthwiseConv2DBias
-        # Set IFM/OFM shapes after changing op type
-        op.set_ifm_ofm_shapes()
+            reduce_h, reduce_w = (True, True)
 
         # Change dimensions to 4
         def extend_dims(dim, in_shape):
@@ -2009,63 +2012,140 @@
             ofm_shape = extend_dims(dims_ofm, ofm_shape)
             op.set_ifm_ofm_shapes()
 
-        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
-        if h > 64:
-            # This can only happen and be done for multiple axes, and
-            # h * w <= 4096 for DepthwiseConv2DBias
-            # which is checked in supported ops
+        # Compute kernel sizes for our convolutions
+        h = shape[1] if reduce_h else 1
+        w = shape[2] if reduce_w else 1
+        num_elements_in_axis = h * w
+
+        # If one convolution is enough, but the height is greater than the max kernel height,
+        # reshape from HxW to 1x(HxW).
+        # This can only be done if the mean is computed over both H and W
+        if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_h and reduce_w:
             shape = [shape[0], 1, h * w, shape[3]]
             op.ifm_shapes[0] = Shape4D(shape)
-            weight_shape = [1, h * w, shape[3], shape[0]]
-        else:
-            # Set weight shape to [H,W,C,B]
-            weight_shape = [h, w, shape[3], shape[0]]
+            op.ifm.shape = shape
+            w = h * w
+            h = 1
 
-        op.rounding_mode = RoundingMode.HalfUp
-        identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
-        op.forced_input_quantization = identity_quant
-        op.forced_output_quantization = identity_quant
+        intermediate_op = None
+        height_per_conv = min(max_kernel_size // w, h)
+        height_per_conv = min(height_per_conv, max_height)
+        num_convs = math.ceil(h / height_per_conv)
+        convs = list()
 
-        # Add unit weight tensor
-        op.set_input_tensor(
-            create_const_tensor(
-                "weights",
+        for i in range(num_convs):
+            is_last_op = i == (num_convs - 1)
+
+            intermediate_op = op.clone(f"{op.name}_conv_{i}")
+
+            intermediate_op.type = Op.DepthwiseConv2DBias
+
+            # Set necessary depthwise attributes
+            intermediate_op.attrs.update(
+                {
+                    "padding": Padding.VALID,
+                    "stride_h": 1,
+                    "stride_w": 1,
+                    "strides": (1, 1, 1, 1),
+                    "depth_multiplier": 1,
+                    "channel_multiplier": 1,
+                    "dilation_h_factor": 1,
+                    "dilation_w_factor": 1,
+                    "dilation": (1, 1, 1, 1),
+                }
+            )
+
+            b, _, _, c = shape
+
+            intermediate_tensor = op.ofm.clone(suffix=f"_conv_sum_{i}", set_unique=True)
+            intermediate_tensor.dtype = DataType.int32
+            intermediate_op.set_output_tensor(intermediate_tensor)
+
+            # As there may be several convs, scaling/rounding must be done after the sum has been calculated
+            intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
+
+            # compute height for the kernel
+            if is_last_op and h % height_per_conv != 0:
+                weight_h = h % height_per_conv
+            else:
+                weight_h = height_per_conv
+
+            # compute ifm read offset and shape for the convolution
+            read_shape_h = weight_h if reduce_h else shape[1]
+            read_shape_w = w if reduce_w else shape[2]
+
+            intermediate_op.read_offsets[0] = Shape4D([0, i * height_per_conv, 0, 0])
+            intermediate_op.read_shapes[0] = Shape4D(shape).with_hw(read_shape_h, read_shape_w)
+
+            weight_quant = QuantizationParameters(0, 255, scale_f32=1.0, zero_point=0)
+            weight_shape = [weight_h, w, c, b]
+            weight_tensor = create_const_tensor(
+                f"{intermediate_op.name}_weights",
                 weight_shape,
-                inp.dtype,
+                DataType.uint8,
                 np.ones(weight_shape),
-                quantization=identity_quant,
-            ),
-            1,
-        )
-        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
+                TensorPurpose.Weights,
+                quantization=weight_quant,
+            )
 
-        # Input zero point is adjusted after the sum calculation, so we emulate that with a bias
-        ofmq, ifmq = op.ofm.quantization, inp.quantization
-        bias = -ifmq.zero_point * h * w
-        bias_shape = [shape[-1]]
-        op.inputs.append(create_const_tensor(op.name + "_bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
-        DebugDatabase.add_optimised(op, op)
+            weights_1D = np.ones(np.prod(weight_shape))
+            weight_tensor.equivalence_id = create_equivalence_id(tuple(weights_1D))
+            weight_tensor.value_id = weight_tensor.equivalence_id
 
-        # Create intermediate tensor between depthwise conv and mul
-        intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
-        intermediate.dtype = DataType.int32
+            intermediate_op.set_input_tensor(weight_tensor, 1)
 
-        # Multiply sum with 1/num_elements_in_axis to get the mean
-        mul_op = Operation(Op.Mul, op.name + "_mul")
-        mul_op.add_input_tensor(intermediate)
-        mul_op.set_output_tensor(op.ofm)
-        mul_op.forced_input_quantization = identity_quant
+            dtype = DataType.int64 if intermediate_op.ifm.dtype == DataType.int16 else DataType.int32
+            bias_values = [0] * c
+            bias = create_const_tensor(f"{intermediate_op.name}_bias", [c], dtype, bias_values)
+            bias.equivalence_id = create_equivalence_id(tuple(bias_values))
+            bias.value_id = bias.equivalence_id
+            intermediate_op.inputs.append(bias)
+            intermediate_op.set_ifm_ofm_shapes()
 
-        # Set dw conv output to the intermediate tensor
-        op.set_output_tensor(intermediate)
+            # We want to avoid reshaping the tensor directly, to not affect other ops
+            # so we update the shape explicitly for this operation
+            intermediate_op.ifm_shapes[0] = Shape4D(shape)
 
-        # Move activation from original op to mean op
-        mul_op.activation = op.activation
-        op.activation = None
+            convs.append(intermediate_op)
+            DebugDatabase.add_optimised(op, intermediate_op)
+
+        # If we have more than one convolution,
+        # we use Add operations to accumulate the intermediate tensors
+        if len(convs) > 1:
+            prev_add_op = None
+            idx = 0
+
+            while len(convs):
+                intermediate_tensor = op.ofm.clone(suffix=f"_add_sum_{idx}", set_unique=True)
+                intermediate_tensor.dtype = DataType.int32
+
+                one_scale_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
+
+                ifm = convs.pop().ofm
+                if not prev_add_op:
+                    ifm2 = convs.pop().ofm
+                else:
+                    ifm2 = prev_add_op.ofm
+
+                intermediate_op = create_add(f"{op.name}_add_{idx}", ifm, ifm2, one_scale_quant)
+                intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
+                intermediate_op.set_output_tensor(intermediate_tensor)
+                intermediate_op.set_ifm_ofm_shapes()
+
+                prev_add_op = intermediate_op
+                idx += 1
+
+                DebugDatabase.add_optimised(op, intermediate_op)
+
+        # Convert the original Mean op to the final Mul operation,
+        # which applies the output scaling and the division by num_elements_in_axis
+        op.type = Op.Mul
+        op.name = f"{op.name}_mul"
+        op.attrs = {}
+        op.set_input_tensor(intermediate_op.ofm, 0)
 
         # The multiplier is calculated in the same way as in the reference,
         # clamping the shift value at the price of some precision loss.
-        num_elements_in_axis = int(h * w)
         output_multiplier, output_shift_vela = quantise_scale(np.double(ifmq.scale_f32) / np.double(ofmq.scale_f32))
 
         # Convert to reference representation shift value
@@ -2084,18 +2164,19 @@
 
         # For int32 scaling is not supported so instead multiply with the scale
         # intermediate * scale -> round and shift.
+        identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
         scalar = create_const_tensor(
             op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [output_multiplier], quantization=identity_quant
         )
-        mul_op.add_input_tensor(scalar)
-        mul_op.set_ifm_ofm_shapes()
+        op.set_input_tensor(scalar, 1)
+        op.set_ifm_ofm_shapes()
 
         # Reference using TFL rounding for the multiply
-        mul_op.rounding_mode = RoundingMode.TFLite
+        op.rounding_mode = RoundingMode.TFLite
 
         # Need to use explicit scaling to get the wanted shift
-        mul_op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
-        DebugDatabase.add_optimised(op, mul_op)
+        op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
+        DebugDatabase.add_optimised(op, op)
     return op
 
 
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index f965d2b..92a7f3c 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -191,7 +191,10 @@
     filter_range = (1, 8)
     filter_height_range = (1, 256)
     filter_product_range = (1, 256 * 256)
-    mean_kernel_product = 64 * 64
+    mean_width_size = 64 * 64
+    mean_kernel_product_int8 = 2 ** (24)
+    mean_kernel_product_uint8 = 2 ** (23)
+    mean_kernel_product_int16 = 2 ** (16)
 
     def __init__(self):
         # Setup the generic constraints. Note: the order matters
@@ -311,7 +314,7 @@
 
         # Mean specific checks:
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
-        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_single_axis)
+        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_width)
 
         # Reshape specific checks:
         self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
@@ -838,14 +841,35 @@
         return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}"
 
     @classmethod
-    @docstring_format_args([mean_kernel_product])
+    @docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16])
     def constraint_mean_height_width_product(cls, op):
-        """Product of height and width must be no greater than {}"""
+        """Product of height and width must be no greater than:
+        - {} for signed 8-bit inputs
+        - {} for unsigned 8-bit inputs
+        - {} for signed 16-bit inputs"""
         shape = op.inputs[0].shape
         hi = 0 if len(shape) < 4 else 1
         h, w = shape[hi : hi + 2]
-        max_prod = cls.mean_kernel_product
-        return h * w <= max_prod, f"Product of height and width is {h * w}"
+        if op.ifm.dtype == DataType.int16:
+            max_prod = cls.mean_kernel_product_int16
+            datatype = "int16"
+        elif op.ifm.dtype == DataType.uint8:
+            max_prod = cls.mean_kernel_product_uint8
+            datatype = "uint8"
+        else:
+            max_prod = cls.mean_kernel_product_int8
+            datatype = "int8"
+        return h * w <= max_prod, f"Datatype is {datatype}, product of height and width is {h * w}"
+
+    @classmethod
+    @docstring_format_args([mean_width_size])
+    def constraint_mean_width(cls, op):
+        """Width must be no greater than {}"""
+        shape = op.inputs[0].shape
+        hi = 0 if len(shape) < 4 else 1
+        h, w = shape[hi : hi + 2]
+        max_width = cls.mean_width_size
+        return w <= max_width, f"Width is {w}"
 
     @classmethod
     @docstring_format_args([dilated_height_range[1]])