MLBEDSW-1499: Add MEAN operator

This commit adds support for the MEAN operator by
converting it to a depthwise convolution, with some
caveats: only certain combinations of axes, shapes
and quantization parameters are supported.

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I165cb26cb5aefd68e70d2cfc68291ccf7b778921
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e1ceb9f..4e7c0fd 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1274,6 +1274,148 @@
     return op
 
 
+def convert_mean_to_depthwise_conv(op, arch, nng):
+    if op.type == Op.Mean and op.run_on_npu:
+        keep_dims = op.attrs.get("keep_dims", False)
+        inp, axis = op.inputs
+        shape = inp.shape
+        dims = len(shape)
+
+        # The height and width axes have different indices depending on the number of dimensions
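+        # e.g. for a 3D IFM of shape [H, W, C] the height axis is 0,
+        # while for a 4D IFM of shape [N, H, W, C] it is 1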
+        if axis.shape == []:  # single axis
+            axis = int(axis.values)
+            if dims in (2, 3):
+                if axis == 0:
+                    h, w = shape[axis], 1
+                else:
+                    h, w = 1, shape[axis]
+            else:
+                if axis == 1:
+                    h, w = shape[axis], 1
+                else:
+                    h, w = 1, shape[axis]
+        else:  # multiple axes
+            axis = sorted(axis.values)
+            h, w = [shape[i] for i in axis]
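+            # Exactly two reduction axes (height and width) are assumed here,
+            # presumably guaranteed by the supported operators check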
+
+        # Set necessary depthwise attributes
+        op.attrs.update(
+            {
+                "padding": Padding.VALID,
+                "stride_h": 1,
+                "stride_w": 1,
+                "strides": (1, 1, 1, 1),
+                "depth_multiplier": 1,
+                "channel_multiplier": 1,
+                "dilation_h_factor": 1,
+                "dilation_w_factor": 1,
+                "dilation": (1, 1, 1, 1),
+            }
+        )
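+        # Together with the HxWxC unit weight tensor created below, these
+        # attributes make the depthwise conv sum all elements in each channel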
+        # Change op type
+        op.type = Op.DepthwiseConv2DBias
+        # Set IFM/OFM shapes after changing op type
+        op.set_ifm_ofm_shapes()
+
+        ofmq, ifmq = op.ofm.quantization, inp.quantization
+        # Set the rounding mode, scaling and zero point based on which reference implementation is being matched
+        if len(shape) == 4 and axis == [1, 2] and keep_dims:
+            if inp.dtype == DataType.uint8:
+                # This attribute means a different scaling calculation is used in order to match the reference
+                op.low_precision_scaling = True
+                weight_scale = h * w
+                foq = ofmq.clone()
+                foq.zero_point -= int(np.round(ifmq.zero_point * ifmq.scale_f32 / foq.scale_f32))
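+                # The input zero point is forced to 0 below, so the equivalent
+                # output offset (ifm_zp * ifm_scale / ofm_scale) is folded into
+                # the forced output zero point instead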
+                op.forced_output_quantization = foq
+                fiq = ifmq.clone()
+                fiq.zero_point = 0
+                op.forced_input_quantization = fiq
+            else:
+                assert inp.dtype == DataType.int8
+                # Use a depthwise conv to calculate the sum,
+                # followed by a multiplication by 1/N to get the mean
+                weight_scale = 1
+                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
+                intermediate.dtype = DataType.int16
+                mul_op = Operation(Op.Mul, op.name + "_mul")
+                mul_op.add_input_tensor(intermediate)
+                # Create scalar containing 1/N
+                quant = QuantizationParameters()
+                quant.zero_point = 0
+                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
+                # while rounding mode NATURAL would round this to -1.
+                # This can only occur if N is even, and can be emulated by
+                # multiplying by a number that is slightly smaller than 1/N.
+                # It must be small enough that other roundings are not affected;
+                # the calculated value is based on the worst case,
+                # which is a sum of 256 * N (the maximum sum that can occur with int8)
+                n = int(h * w)
+                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
+                quant.scale_f32 = 1 / (n - eps)
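+                # e.g. for N = 4: eps = 1/1280 and scale_f32 = 1/(4 - 1/1280) ~= 0.25005,
+                # so a sum of -6 scales to about -1.5003 and rounds (NATURAL) to -2,
+                # matching the reference, while a sum of e.g. 5 still rounds to 1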
+                scalar = create_const_tensor(
+                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
+                )
+                mul_op.add_input_tensor(scalar)
+                mul_op.set_output_tensor(op.ofm)
+                mul_op.set_ifm_ofm_shapes()
+                mul_op.rounding_mode = NpuRoundingMode.NATURAL
+                mul_op.activation = op.activation
+                op.activation = None
+                op.set_output_tensor(intermediate)
+                op.set_ifm_ofm_shapes()
+        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
+            op.rounding_mode = NpuRoundingMode.TRUNCATE
+            weight_scale = 1 / (h * w)
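+            # With identical input and output quantization, the mean can be
+            # computed directly by scaling the unit weights with 1/(h*w)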
+            foq = ofmq.clone()
+            foq.zero_point = 0
+            op.forced_output_quantization = foq
+            fiq = ifmq.clone()
+            fiq.zero_point = 0
+            op.forced_input_quantization = fiq
+        else:
+            raise UnsupportedFeatureError("Mean operators with these attributes are currently not supported")
+
+        # Expand the shape to 4 dimensions
+        if dims < 4:
+            shape = [1] + shape
+            if dims == 2:
+                shape += [1]
+
+        # If height is greater than the max kernel height, reshape from HxW to 1x(HxW)
+        if h > 64:
+            shape = [shape[0], 1, h * w, shape[3]]
+            op.ifm_shapes[0] = Shape4D(shape)
+            inp.avoid_NHCWB16 = True
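+            # e.g. an IFM of shape [1, 100, 10, C] is reshaped to [1, 1, 1000, C]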
+
+        # Add None bias tensor
+        op.inputs.append(None)
+        # Create the quantization for the unit weight tensor
+        weight_quant = inp.quantization.clone()
+        weight_quant.min = 0
+        weight_quant.max = 255
+        weight_quant.scale_f32 = weight_scale
+        weight_quant.zero_point = 0
+
+        # Set weight shape to [H,W,C,B]
+        weight_shape = shape[1:4] + [shape[0]]
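+        # e.g. a batch-1 IFM of shape [1, H, W, C] gives weights of shape [H, W, C, 1]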
+        # Add unit weight tensor
+        op.set_input_tensor(
+            create_const_tensor(
+                "weights",
+                weight_shape,
+                inp.dtype,
+                np.ones(weight_shape),
+                value_dtype=np.uint8,
+                quantization=weight_quant,
+            ),
+            1,
+        )
+        op.inputs[1].quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
+
+    return op
+
+
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.supported_operators.is_operator_supported(op)
     return op
@@ -1337,6 +1479,7 @@
 
     op_rewrite_list = [
         set_tensor_equivalence,
+        convert_mean_to_depthwise_conv,
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,