MLBEDSW-4215: Add support for MEAN to match QuantizedMeanOrSum implementation

This commit adds support for emulating the behavior of the
QuantizedMeanOrSum reference implementation of MEAN in
TensorFlow Lite, and removes the corresponding restriction
from the supported operators check.
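
When none of the reference-specific cases apply, MEAN is now
lowered to a depthwise convolution with unit weights, NATURAL
rounding, weight_scale = 1 / (h * w), a zeroed input zero point
and a compensating bias of -ifm_zero_point * h * w. A minimal
NumPy sketch of the arithmetic this fallback is intended to
reproduce (illustrative only, not Vela code; round-half-away-from-zero
stands in for the NPU's NATURAL rounding mode):

    import numpy as np

    def emulated_mean(ifm_q, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
        # ifm_q: quantized h x w input for a single channel
        h, w = ifm_q.shape
        # A unit-weight depthwise conv sums all h * w elements; the
        # input zero point is forced to 0 and compensated via a bias.
        acc = int(np.sum(ifm_q, dtype=np.int64)) - ifm_zp * h * w
        # weight_scale = 1 / (h * w) folds the mean's division into
        # the requantization from IFM scale to OFM scale
        x = acc * ifm_scale / (h * w * ofm_scale)
        rounded = int(np.sign(x) * np.floor(np.abs(x) + 0.5))
        return rounded + ofm_zp

For example, a 2x2 uint8 block [[12, 14], [13, 13]] with matching
IFM/OFM quantization and zero points of 0 gives 52 / 4 = 13.

The create_const_tensor helper gains a quant_value_dtype argument
so that the new int32 bias constant is not reinterpreted as raw
uint8 bytes when its quant_values are created.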

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: Ifd24e0e678e2f85cd66ab82deeaaf010d5351b1e

diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 3c90e20..1ad65c6 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -197,14 +197,6 @@
 This is a list of constraints that the MEAN operator must satisfy in order to be scheduled on the NPU.
 
 - IFM must be int8 or uint8
-- Every constraint in either one (or both) of the following sets of constraints must be fulfilled:  
-        Set A:  
-            IFM dimensions are 4,  
-            Axis indices are 1 and 2,  
-            keep_dims is set to True  
-        Set B:  
-            IFM zero point and OFM zero point are the same,  
-            IFM scale and OFM scale are the same
 - Input tensor must be at least 2D
 - Axis indices must correspond to height and width axes
 - Product of height and width can be at most 4096
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 3084117..e8218fc 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1422,9 +1422,12 @@
         )
         # Change op type
         op.type = Op.DepthwiseConv2DBias
+        # Add None bias tensor
+        op.inputs.append(None)
         # Set IFM/OFM shapes after changing op type
         op.set_ifm_ofm_shapes()
 
+        weight_scale, bias = 1, None
         ofmq, ifmq = op.ofm.quantization, inp.quantization
         # Set rounding mode, scaling and zero point based on which reference implementation to match
         if len(shape) == 4 and axis == [1, 2] and keep_dims:
@@ -1442,7 +1445,6 @@
                 assert inp.dtype == DataType.int8
                 # Use a depthwise to calculate the sum,
                 # followed by a multiplication with 1/N to get the MEAN
-                op.type = Op.DepthwiseConv2DBias
                 weight_scale = 1
                 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                 intermediate.dtype = DataType.int16
@@ -1482,7 +1484,13 @@
             fiq.zero_point = 0
             op.forced_input_quantization = fiq
         else:
-            raise UnsupportedFeatureError("Mean operators with these attributes are currently not supported")
+            op.rounding_mode = NpuRoundingMode.NATURAL
+            weight_scale = 1 / (h * w)
+            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
+            bias = -ifmq.zero_point * h * w
+            fiq = ifmq.clone()
+            fiq.zero_point = 0
+            op.forced_input_quantization = fiq
 
         # Change dimensions to 4
         if dims < 4:
@@ -1496,10 +1504,8 @@
             op.ifm_shapes[0] = Shape4D(shape)
             inp.avoid_NHCWB16 = True
 
-        # Add None bias tensor
-        op.inputs.append(None)
         # Make unit weight tensor quantization
-        weight_quant = inp.quantization.clone()
+        weight_quant = ifmq.clone()
         weight_quant.min = 0
         weight_quant.max = 255
         weight_quant.scale_f32 = weight_scale
@@ -1519,7 +1525,23 @@
             ),
             1,
         )
-        op.inputs[1].quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
+        op.weights.quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
+
+        # Add bias tensor
+        if bias:
+            bias_shape = [shape[-1]]
+            op.set_input_tensor(
+                create_const_tensor(
+                    "bias",
+                    bias_shape,
+                    inp.dtype,
+                    np.ones(bias_shape) * bias,
+                    value_dtype=np.int32,
+                    quant_value_dtype=np.int32,
+                    quantization=None,
+                ),
+                2,
+            )
 
     return op
 
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 2319706..777e9c7 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -270,7 +270,6 @@
         self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types)
         # Mean specific checks:
         self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_input_8bit)
-        self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_properties)
         self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_input_dims)
         self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_axis)
         self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product)
@@ -1076,35 +1075,3 @@
         h, w = shape[hi : hi + 2]
         max_prod = cls.mean_kernel_product_int8
         return h * w <= max_prod, f"Product of height and width is {h * w}"
-
-    @staticmethod
-    def constraint_mean_properties(op):
-        """Every constraint in either one (or both) of the following sets of constraints must be fulfilled:
-        Set A:
-            IFM dimensions are 4,
-            Axis indices are 1 and 2,
-            keep_dims is set to True
-        Set B:
-            IFM zero point and OFM zero point are the same,
-            IFM scale and OFM scale are the same"""
-        seta, setb = True, True
-        extra = []
-        axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
-        if len(op.ifm.shape) != 4:
-            seta = False
-            extra.append(f"IFM shape is {op.ifm.shape}")
-        if not any(np.array_equal(axis, ax) for ax in ([1, 2], [2, 1])):
-            seta = False
-            extra.append(f"Axis is {axis}")
-        if not op.attrs.get("keep_dims"):
-            seta = False
-            extra.append("keep_dims is False")
-        ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
-        if ifmq.zero_point != ofmq.zero_point:
-            setb = False
-            extra.append("IFM zero point does not match OFM zero point")
-        if ifmq.scale_f32 != ofmq.scale_f32:
-            setb = False
-            extra.append("IFM scale does not match OFM scale")
-        extra = ", ".join(extra)
-        return seta or setb, f"The following constraints were not fulfilled: {extra}"
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 97885d0..e915363 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -300,13 +300,16 @@
     value_dtype: np.dtype = None,
     purpose: TensorPurpose = TensorPurpose.Unknown,
     quantization: QuantizationParameters = None,
+    quant_value_dtype: np.dtype = None,
 ):
     # Tensor
     const_tensor = Tensor(shape, dtype, name + "_0")
     const_tensor.purpose = purpose
     const_tensor.quantization = quantization
     const_tensor.values = np.array(values, dtype=value_dtype)
-    const_tensor.quant_values = np.frombuffer(const_tensor.values.tobytes(), dtype=np.uint8)
+    const_tensor.quant_values = np.frombuffer(
+        const_tensor.values.tobytes(), dtype=np.uint8 if not quant_value_dtype else quant_value_dtype
+    )
     # Operator
     const_op = Operation(Op.Const, name)
     const_op.set_output_tensor(const_tensor)
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 34ddb90..aad2849 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -858,13 +858,6 @@
     assert not support.is_operator_supported(op)
 
 
-def test_mean_properties():
-    op = create_mean([1, 6, 6, 256], [1, 1, 256], [1, 2], DataType.uint8, {})
-    assert support.is_operator_supported(op)
-    op.ifm.quantization.zero_point = 55
-    assert not support.is_operator_supported(op)
-
-
 def test_mean_axis():
     op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
     assert not support.is_operator_supported(op)