MLBEDSW-6909: Use int32 acc for the Mean op

Changed the accumulator type from int16 to int32. This solves
the saturation problems, so the constraint added in
commit "MLBEDSW-5029: Output diff for Mean op"
can be removed.

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I05ec8835b43313b1a264d61a2b147fa62da123fe
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index cc8b3d2..89c2799 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -623,22 +623,6 @@
     op = create_mean([1, 16, 17, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
     assert not support.is_operator_supported(op)
 
-    # Create OP that will not saturate the accumulator
-    op = create_mean([1, 5, 14, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
-    op.ifm.quantization.scale_f32 = 2.0
-    op.ifm.quantization.zero_point = 95
-    op.ofm.quantization.scale_f32 = 1.0
-    op.ofm.quantization.zero_point = 95
-    assert support.is_operator_supported(op)
-
-    # Create OP that can saturate the accumulator
-    op = create_mean([1, 6, 14, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
-    op.ifm.quantization.scale_f32 = 2.0
-    op.ifm.quantization.zero_point = 95
-    op.ofm.quantization.scale_f32 = 1.0
-    op.ofm.quantization.zero_point = 95
-    assert not support.is_operator_supported(op)
-
 
 def test_mean_hw_product_avgpool():
     op = create_mean([1, 200, 200, 16], [1, 16], [1, 2], DataType.uint8, {"keep_dims": False})
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index aaa778e..0f199de 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1476,9 +1476,10 @@
                 # followed by a multiplication with 1/N to get the MEAN
                 weight_scale = 1
                 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
-                intermediate.dtype = DataType.int16
+                intermediate.dtype = DataType.int32
                 mul_op = Operation(Op.Mul, op.name + "_mul")
                 mul_op.add_input_tensor(intermediate)
+                mul_op.set_output_tensor(op.ofm)
                 # Create scalar containing 1/N
                 quant = QuantizationParameters()
                 quant.zero_point = 0
@@ -1492,11 +1493,23 @@
                 n = int(h * w)
                 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
                 quant.scale_f32 = 1 / (n - eps)
+
+                # For int8/int16 we could use IFM/OFM scaling to do the division
+                # intermediate * 1 -> scale -> round and shift.
+                #
+                # For int32 scaling is not supported so instead multiply with the scale
+                # intermediate * scale -> round and shift.
+                #
+                # Calculate the scale and shift values. The const tensor must be created
+                # with the correct quantization since the scale and shift are calculated
+                # later in the command stream generator.
+                mul_scale, _ = scaling.elementwise_mul_scale(
+                    mul_op.ifm.quantization.scale_f32, quant.scale_f32, mul_op.ofm.quantization.scale_f32
+                )
                 scalar = create_const_tensor(
-                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
+                    op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [mul_scale], np.int32, quantization=quant
                 )
                 mul_op.add_input_tensor(scalar)
-                mul_op.set_output_tensor(op.ofm)
                 mul_op.set_ifm_ofm_shapes()
                 mul_op.rounding_mode = NpuRoundingMode.NATURAL
                 mul_op.activation = op.activation
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index f01a669..24cc26e 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -796,9 +796,10 @@
         max_prod = cls.mean_kernel_product
         return h * w <= max_prod, f"Product of height and width is {h * w}"
 
-    @staticmethod
-    def constraint_mean_height_width_product_int8(op):
-        """Number of IFM height and width elements might cause accumulator saturation when;
+    @classmethod
+    @docstring_format_args([mean_kernel_product_int8])
+    def constraint_mean_height_width_product_int8(cls, op):
+        """Product of IFM height and width must be no greater than {} when:
         The IFM shape has 4 dimensions; and
         The axis indices specify reduction across 2 dimensions; and
         The axis indices correspond to the width and height dimensions of the IFM; and
@@ -817,43 +818,8 @@
             return True, ""
         h = shape[-3]
         w = shape[-2]
-
-        ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
-
-        # Scale factor
-        real_scale = ifmq.scale_f32 / ofmq.scale_f32
-
-        # Min and max value
-        ifm_min_val = np.iinfo(np.int8).min - ifmq.zero_point
-        ifm_max_val = np.iinfo(np.int8).max - ifmq.zero_point
-
-        # Accumulator limits
-        min_acc_limit = np.iinfo(np.int16).min
-        max_acc_limit = np.iinfo(np.int16).max
-
-        # Theoretical max/min value that accumulator need to store
-        min_acc_sum = h * w * ifm_min_val * real_scale + ofmq.zero_point
-        max_acc_sum = h * w * ifm_max_val * real_scale + ofmq.zero_point
-
-        # Max product of heigth and width that will not saturate the accumulator
-        ifm_min_val = 1 if ifm_min_val == 0 else ifm_min_val
-        ifm_max_val = 1 if ifm_max_val == 0 else ifm_max_val
-        if max_acc_sum > abs(min_acc_sum):
-            max_hw = int((max_acc_limit - ofmq.zero_point) / real_scale / ifm_max_val)
-        else:
-            max_hw = int((min_acc_limit - ofmq.zero_point) / real_scale / ifm_min_val)
-
-        extra = []
-
-        extra.append(f"   Possible accumulator range is ({min_acc_sum} - {max_acc_sum})\n")
-        extra.append(f"   Maximum  accumulator range is ({min_acc_limit} - {max_acc_limit})\n")
-        extra.append(
-            f"   Based on the IFM and OFM quantization the IFM height and width must be no greater than {max_hw}"
-        )
-
-        extra = "".join(extra)
-
-        return (min_acc_sum >= min_acc_limit and max_acc_sum <= max_acc_limit, f"\n{extra}")
+        max_prod = cls.mean_kernel_product_int8
+        return h * w <= max_prod, f"Product of height and width is {h * w}"
 
     @classmethod
     @docstring_format_args([filter_height_range[1], dilated_height_range[1]])