MLBEDSW-1499: Add MEAN operator

This commit adds support for the MEAN operator,
lowering it to a depthwise convolution with unit weights (plus a
trailing MUL with 1/N in one int8 case), with some caveats: the IFM
must be 8-bit and 2D to 4D, the reduction axes must be the height and
width axes, and the height*width product is limited to 64*64 (16*16
for the int8 case with keep_dims).
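
As a rough sanity sketch of the lowering (values only; the quantized
rescaling is handled separately in this change), a MEAN over the H and W
axes of an NHWC tensor is equivalent to a depthwise convolution with an
all-ones HxWxC kernel, VALID padding and a weight scale of 1/(H*W). The
numpy sketch below is purely illustrative:

import numpy as np

ifm = np.random.rand(1, 6, 6, 16).astype(np.float32)  # NHWC input
h, w, c = ifm.shape[1], ifm.shape[2], ifm.shape[3]
unit_kernel = np.ones((h, w, c), dtype=np.float32)    # all-ones depthwise kernel
# A depthwise convolution whose kernel covers the whole HxW plane with
# VALID padding degenerates to a per-channel sum over H and W.
dw_sum = np.einsum("bhwc,hwc->bc", ifm, unit_kernel).reshape(1, 1, 1, c)
mean = dw_sum / (h * w)                                # weight scale 1/(H*W)
assert np.allclose(mean, ifm.mean(axis=(1, 2), keepdims=True))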

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I165cb26cb5aefd68e70d2cfc68291ccf7b778921
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e1ceb9f..4e7c0fd 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1274,6 +1274,148 @@
     return op
 
 
+def convert_mean_to_depthwise_conv(op, arch, nng):
+    if op.type == Op.Mean and op.run_on_npu:
+        keep_dims = op.attrs.get("keep_dims", False)
+        inp, axis = op.inputs
+        shape = inp.shape
+        dims = len(shape)
+
+        # Height and width axes have different indices depending on the number of dimensions
+        if axis.shape == []:  # single axis
+            axis = int(axis.values)
+            if dims in (2, 3):
+                if axis == 0:
+                    h, w = shape[axis], 1
+                else:
+                    h, w = 1, shape[axis]
+            else:
+                if axis == 1:
+                    h, w = shape[axis], 1
+                else:
+                    h, w = 1, shape[axis]
+        else:  # multiple axes
+            axis = sorted(axis.values)
+            h, w = [shape[i] for i in axis]
+
+        # Set necessary depthwise attributes
+        op.attrs.update(
+            {
+                "padding": Padding.VALID,
+                "stride_h": 1,
+                "stride_w": 1,
+                "strides": (1, 1, 1, 1),
+                "depth_multiplier": 1,
+                "channel_multiplier": 1,
+                "dilation_h_factor": 1,
+                "dilation_w_factor": 1,
+                "dilation": (1, 1, 1, 1),
+            }
+        )
+        # Change op type
+        op.type = Op.DepthwiseConv2DBias
+        # Set IFM/OFM shapes after changing op type
+        op.set_ifm_ofm_shapes()
+
+        ofmq, ifmq = op.ofm.quantization, inp.quantization
+        # Set rounding mode, scaling and zero point based on which reference implementation to match
+        if len(shape) == 4 and axis == [1, 2] and keep_dims:
+            if inp.dtype == DataType.uint8:
+                # This attribute means a different scaling calculation is used in order to match reference
+                op.low_precision_scaling = True
+                weight_scale = h * w
+                foq = ofmq.clone()
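+                # With the forced IFM zero point of 0, the accumulated sum carries an extra
+                # N * ifm_zero_point. Since the effective rescale is ifm_scale / (N * ofm_scale)
+                # (weight_scale is N and ends up in the denominator, see the low_precision_scaling
+                # path in weight_compressor.py), that surplus equals ifm_zero_point * ifm_scale / ofm_scale
+                # in OFM units and is folded into the forced output zero point below.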
+                foq.zero_point -= int(np.round(ifmq.zero_point * ifmq.scale_f32 / foq.scale_f32))
+                op.forced_output_quantization = foq
+                fiq = ifmq.clone()
+                fiq.zero_point = 0
+                op.forced_input_quantization = fiq
+            else:
+                assert inp.dtype == DataType.int8
+                # Use a depthwise to calculate the sum,
+                # followed by a multiplication with 1/N to get the MEAN
+                op.type = Op.DepthwiseConv2DBias
+                weight_scale = 1
+                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
+                intermediate.dtype = DataType.int16
+                mul_op = Operation(Op.Mul, op.name + "_mul")
+                mul_op.add_input_tensor(intermediate)
+                # Create scalar containing 1/N
+                quant = QuantizationParameters()
+                quant.zero_point = 0
+                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
+                # while rounding mode NATURAL would round this to -1.
+                # This can only occur if N is even, and can be emulated by
+                # multiplying with 1 / (N - eps), which is slightly larger than 1/N.
+                # It must be so small that other roundings are not affected;
+                # the calculated value is based on worst case,
+                # which is a sum of 256 * N (an upper bound on the sum magnitude with int8 input)
+                n = int(h * w)
+                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
+                quant.scale_f32 = 1 / (n - eps)
+                scalar = create_const_tensor(
+                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
+                )
+                mul_op.add_input_tensor(scalar)
+                mul_op.set_output_tensor(op.ofm)
+                mul_op.set_ifm_ofm_shapes()
+                mul_op.rounding_mode = NpuRoundingMode.NATURAL
+                mul_op.activation = op.activation
+                op.activation = None
+                op.set_output_tensor(intermediate)
+                op.set_ifm_ofm_shapes()
+        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
+            op.rounding_mode = NpuRoundingMode.TRUNCATE
+            weight_scale = 1 / (h * w)
+            foq = ofmq.clone()
+            foq.zero_point = 0
+            op.forced_output_quantization = foq
+            fiq = ifmq.clone()
+            fiq.zero_point = 0
+            op.forced_input_quantization = fiq
+        else:
+            raise UnsupportedFeatureError("Mean operators with these attributes are currently not supported")
+
+        # Change dimensions to 4
+        if dims < 4:
+            shape = [1] + shape
+            if dims == 2:
+                shape += [1]
+
+        # If height is greater than max kernel height (64), reshape from HxW to 1x(HxW)
+        if h > 64:
+            shape = [shape[0], 1, h * w, shape[3]]
+            op.ifm_shapes[0] = Shape4D(shape)
+            inp.avoid_NHCWB16 = True
+
+        # Add None bias tensor
+        op.inputs.append(None)
+        # Make unit weight tensor quantization
+        weight_quant = inp.quantization.clone()
+        weight_quant.min = 0
+        weight_quant.max = 255
+        weight_quant.scale_f32 = weight_scale
+        weight_quant.zero_point = 0
+
+        # Set weight shape to [H,W,C,B]
+        weight_shape = shape[1:4] + [shape[0]]
+        # Add unit weight tensor
+        op.set_input_tensor(
+            create_const_tensor(
+                "weights",
+                weight_shape,
+                inp.dtype,
+                np.ones(weight_shape),
+                value_dtype=np.uint8,
+                quantization=weight_quant,
+            ),
+            1,
+        )
+        op.inputs[1].quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
+
+    return op
+
+
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.supported_operators.is_operator_supported(op)
     return op
@@ -1337,6 +1479,7 @@
 
     op_rewrite_list = [
         set_tensor_equivalence,
+        convert_mean_to_depthwise_conv,
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,
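
A note on the 1 / (N - eps) scalar in convert_mean_to_depthwise_conv above:
it exists only to make NATURAL rounding reproduce the reference behaviour for
negative halves. A small, self-contained check of that claim, using plain
Python floats rather than the NPU rescale pipeline (helper names are
illustrative):

import math

def round_away_from_zero(x):
    # Reference behaviour described above: halves round away from zero
    return int(math.copysign(math.floor(abs(x) + 0.5), x))

def round_natural(x):
    # NATURAL rounding: halves round towards +infinity, so -1.5 -> -1
    return math.floor(x + 0.5)

n = 4                                   # must be even for the mismatch to occur
eps = 1 / (256 * (n + 1))               # same epsilon as above
for s in range(-256 * n, 256 * n + 1):  # worst-case range of integer sums
    assert round_natural(s / (n - eps)) == round_away_from_zero(s / n)
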
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 1059e6e..56c5e74 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -209,13 +209,15 @@
 
 def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
     """Gets quantization for IFM/IFM2"""
-    if tens.quantization is None:
+    op = ps.primary_op
+    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
+    if ifm_quant is None:
         return None
     if use_zero_point_0(ps, tens, True):
         zero_point = 0
     else:
-        zero_point = int(tens.quantization.zero_point)
-    return NpuQuantization(scale_f32=tens.quantization.scale_f32, zero_point=zero_point)
+        zero_point = int(ifm_quant.zero_point)
+    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)
 
 
 def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
@@ -389,8 +391,7 @@
     npu_op = NpuPoolingOperation(pool_op)
     set_common_op_fields(npu_op, cmd, arch)
     # Pooling specific info
-    if op.type == Op.ResizeBilinear:
-        npu_op.rescale = op.rescale
+    npu_op.rescale = op.rescale
     return npu_op
 
 
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 820c7d6..e315f1f 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -509,8 +509,7 @@
                 )
             else:
                 weight_tensor_shape = [
-                    primary_op.attrs["ksize"][1],
-                    primary_op.attrs["ksize"][2],
+                    *primary_op.get_kernel_size(),
                     1,
                     ifm_tensor_shape.depth,
                 ]
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 967d30b..d2b08b5 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -196,7 +196,7 @@
     Max = OperatorInfo()
     MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
     Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
-    Mean = OperatorInfo()
+    Mean = OperatorInfo(indices=IFM_INDICES)
     Min = OperatorInfo()
     Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
     MirrorPad = OperatorInfo()
@@ -414,6 +414,7 @@
         "run_on_npu",
         "activation",
         "memory_function",
+        "forced_input_quantization",
         "forced_output_quantization",
         "activation_lut",
         "_kernel",
@@ -422,6 +423,7 @@
         "rescale",
         "read_offsets",
         "rounding_mode",
+        "low_precision_scaling",
     )
 
     def __init__(self, op_type: Op, name: str):
@@ -439,6 +441,7 @@
         self.memory_function = None
         # If not none: contains QuantizationParameters to be used as output quantization
         # (which overrides the ofm tensor's quantization), used in LUT
+        self.forced_input_quantization = None
         self.forced_output_quantization = None
         self.scheduled_pass = None
         self.op_index = None  # input network operator index
@@ -451,6 +454,9 @@
         self.rescale = None
         self.read_offsets: List[Shape4D] = [None, None]  # offset for [ifm, ifm2]
         self.rounding_mode: Optional[NpuRoundingMode] = None
+        # The Mean operator (implemented as a depthwise convolution) requires scaling
+        # to be calculated differently in one case. In that case, this is set to True.
+        self.low_precision_scaling = False
 
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
@@ -463,11 +469,13 @@
         res.run_on_npu = self.run_on_npu
         res.activation = None if self.activation is None else self.activation.clone()
         res.memory_function = self.memory_function
+        res.forced_input_quantization = self.forced_input_quantization
         res.forced_output_quantization = self.forced_output_quantization
         res.scheduled_pass = self.scheduled_pass
         res.op_index = None  # not relevant as not part of input network
         res.read_offsets = list(self.read_offsets)
         res.rounding_mode = self.rounding_mode
+        res.low_precision_scaling = self.low_precision_scaling
 
         return res
 
@@ -692,6 +700,11 @@
         if self not in tens.consumer_list:
             tens.consumer_list.append(self)
 
+    def get_input_quantization(self):
+        if self.forced_input_quantization is not None:
+            return self.forced_input_quantization
+        return self.ifm.quantization
+
     def set_output_tensor(self, tens):
         tens.ops = [self]
         self.outputs = [tens]
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 8b759be..a82f812 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -75,6 +75,8 @@
         | resizing_ops
         # FC layers
         | fc_vector_products
+        # Mean (converts to depthwise conv)
+        | set((Op.Mean,))
     )
     unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
     binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
@@ -99,7 +101,7 @@
     split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
     concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
     memory_only_ops = set((Op.Reshape, Op.QuantizedReshape,)) | concat_ops | split_ops
-    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
+    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV, Op.Mean))
     per_axis_quant_ops = convolution_like_ops  # per-axis/channel quantization only currently supported for conv ops
     supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
     supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | pad_ops | npu_post_ops | memory_only_ops
@@ -118,6 +120,8 @@
     filter_range = (1, 8)
     filter_height_range = (1, 256)
     filter_product_range = (1, 256 * 256)
+    mean_kernel_product = 64 * 64
+    mean_kernel_product_int8 = 16 * 16
     # Supported consumers
     supported_pad_consumers = convolution_ops | depthwise_convolution_ops | pooling_ops
 
@@ -268,6 +272,13 @@
         # HardSwish specific checks:
         self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
         self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types)
+        # Mean specific checks:
+        self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_input_8bit)
+        self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_properties)
+        self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_input_dims)
+        self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_axis)
+        self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product)
+        self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_int8)
 
     def is_operator_supported(self, op):
         ext_type = optype_to_builtintype(op.type)
@@ -1077,3 +1088,83 @@
         if op.attrs.get("keep_num_dims"):
             valid = len(op.ifm.shape) == len(op.ofm.shape)
         return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"
+
+    @staticmethod
+    def constraint_mean_input_dims(op):
+        "Input tensor must be at least 2D and at most 4D"
+        dims = len(op.inputs[0].shape)
+        return 2 <= dims <= 4, f"Input is {dims}D"
+
+    @staticmethod
+    def constraint_mean_axis(op):
+        "Axis indices must correspond to height and width axes"
+        dims = len(op.inputs[0].shape)
+        axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+        if dims == 2 or dims == 3:
+            valid = axis in (0, 1, [0, 1], [1, 0])
+        elif dims == 4:
+            valid = axis in (1, 2, [1, 2], [2, 1])
+        return valid, f"Axis is {axis}"
+
+    @classmethod
+    @docstring_format_args([mean_kernel_product])
+    def constraint_mean_height_width_product(cls, op):
+        "Product of height and width can be at most {}"
+        shape = op.inputs[0].shape
+        hi = 0 if len(shape) < 4 else 1
+        h, w = shape[hi : hi + 2]
+        max_prod = cls.mean_kernel_product
+        return h * w <= max_prod, f"Product of height and width is {h * w}"
+
+    @classmethod
+    @docstring_format_args([mean_kernel_product_int8])
+    def constraint_mean_height_width_product_int8(cls, op):
+        """Product of IFM height and width can be at most {} when the following are true:
+        IFM dimensions are 4,
+        Axis indices are 1 and 2,
+        keep_dims is set to True and
+        IFM datatype is int8"""
+        shape = op.ifm.shape
+        axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+        if (
+            len(shape) != 4
+            or op.ifm.dtype != DataType.int8
+            or not op.attrs.get("keep_dims")
+            or axis not in ([1, 2], [2, 1])
+        ):
+            return True, ""
+        hi = 0 if len(shape) < 4 else 1
+        h, w = shape[hi : hi + 2]
+        max_prod = cls.mean_kernel_product_int8
+        return h * w <= max_prod, f"Product of height and width is {h * w}"
+
+    @staticmethod
+    def constraint_mean_properties(op):
+        """Every constraint in either one (or both) of the following sets of constraints must be fulfilled:
+        Set A:
+            IFM dimensions are 4,
+            Axis indices are 1 and 2,
+            keep_dims is set to True
+        Set B:
+            IFM zero point and OFM zero point are the same,
+            IFM scale and OFM scale are the same"""
+        seta, setb = True, True
+        extra = []
+        axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+        if len(op.ifm.shape) != 4:
+            seta = False
+            extra.append(f"IFM shape is {op.ifm.shape}")
+        if not any(np.array_equal(axis, ax) for ax in ([1, 2], [2, 1])):
+            seta = False
+            extra.append(f"Axis is {axis}")
+        if not op.attrs.get("keep_dims"):
+            seta = False
+            extra.append("keep_dims is False")
+        ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
+        if ifmq.zero_point != ofmq.zero_point:
+            setb = False
+            extra.append("IFM zero point does not match OFM zero point")
+        if ifmq.scale_f32 != ofmq.scale_f32:
+            setb = False
+            extra.append("IFM scale does not match OFM scale")
+        extra = ", ".join(extra)
+        return seta or setb, f"The following constraints were not fulfilled: {extra}"
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 832d60f..cd331fd 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -953,3 +953,47 @@
     assert not support.is_operator_supported(op)
     op.attrs["keep_num_dims"] = False
     assert support.is_operator_supported(op)
+
+
+def create_mean(input_shape, output_shape, indices, datatype, attrs):
+    ifm = Tensor(input_shape, datatype, "in")
+    ifm.quantization = testutil.default_quant_params()
+    indices = create_const_tensor("indices", [len(indices)], DataType.int32, indices, np.uint8)
+    ofm = Tensor(output_shape, datatype, "out")
+    ofm.quantization = testutil.default_quant_params()
+    op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
+    return op
+
+
+def test_mean_dtype():
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+    assert support.is_operator_supported(op)
+    op.ifm.dtype = DataType.int16
+    op.ofm.dtype = DataType.int16
+    assert not support.is_operator_supported(op)
+
+
+def test_mean_properties():
+    op = create_mean([1, 6, 6, 256], [1, 1, 256], [1, 2], DataType.uint8, {})
+    assert support.is_operator_supported(op)
+    op.ifm.quantization.zero_point = 55
+    assert not support.is_operator_supported(op)
+
+
+def test_mean_axis():
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+    assert not support.is_operator_supported(op)
+
+
+def test_mean_hw_product():
+    op = create_mean([1, 64, 64, 16], [1, 1, 16], [1, 2], DataType.uint8, {})
+    assert support.is_operator_supported(op)
+    op = create_mean([1, 65, 64, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+    assert not support.is_operator_supported(op)
+
+
+def test_mean_hw_product_int8():
+    op = create_mean([1, 16, 16, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+    assert support.is_operator_supported(op)
+    op = create_mean([1, 16, 17, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+    assert not support.is_operator_supported(op)
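
These constraints can be exercised in isolation with pytest, for example:

    pytest ethosu/vela/test/test_supported_operators.py -k mean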
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index 41d57c0..b526ec5 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -563,7 +563,7 @@
     BuiltinOperator.BATCH_TO_SPACE_ND: (Op.BatchToSpaceND, OptionsSerializer("BatchToSpaceNDOptions")),
     BuiltinOperator.SPACE_TO_BATCH_ND: (Op.SpaceToBatchND, OptionsSerializer("SpaceToBatchNDOptions")),
     BuiltinOperator.TRANSPOSE: (Op.Transpose, OptionsSerializer("TransposeOptions")),
-    BuiltinOperator.MEAN: (Op.Mean, None),
+    BuiltinOperator.MEAN: (Op.Mean, reducer_opts),
     BuiltinOperator.SUB: (Op.Sub, OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",))),
     BuiltinOperator.DIV: (Op.Div, OptionsSerializer("DivOptions", (fused_act,))),
     BuiltinOperator.SQUEEZE: (Op.Squeeze, OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),))),
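
The low_precision_scaling flag set in convert_mean_to_depthwise_conv is
consumed in the weight_compressor.py change below. Because the uint8 Mean
path gives the unit weights scale H * W (weight_scale = h * w), the weight
scale ends up in the denominator there, and the scales are combined in
float32 before being widened to double, presumably mirroring lower-precision
arithmetic in the reference kernel. A hypothetical comparison of the two
precision routes for the same ratio, with made-up scale values:

import numpy as np

ifm_scale, ofm_scale, weight_scale = 0.007874016, 0.023529412, 36.0  # made-up: a 6x6 MEAN

# Combine in float32 first, then widen (the low_precision_scaling route)
lp = np.double(np.single(ifm_scale) / (np.single(weight_scale) * np.single(ofm_scale)))
# Combine the same ratio directly in double precision
dp = np.double(ifm_scale) / (np.double(weight_scale) * np.double(ofm_scale))
print(lp, dp, lp == dp)  # the float32 route can differ in the last bits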
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index b291dce..bb7cd67 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -426,13 +426,13 @@
 
     first_consumer_op = tens.consumer_list[0]
     ifm_dtype = first_consumer_op.inputs[0].dtype
-    ifm_scale = first_consumer_op.inputs[0].quantization.scale_f32
+    ifm_scale = first_consumer_op.get_input_quantization().scale_f32
     ofm_scale = first_consumer_op.get_output_quantization().scale_f32
     weight_scales = first_consumer_op.inputs[1].quantization.scale_f32
 
     # biases can have multiple consumers for rnn cells. if so, then check that they are all the same
     for op in tens.consumer_list[1:]:
-        assert ifm_scale == op.inputs[0].quantization.scale_f32
+        assert ifm_scale == op.get_input_quantization().scale_f32
         assert ofm_scale == op.get_output_quantization().scale_f32
         assert weight_scales == op.inputs[1].quantization.scale_f32
 
@@ -445,7 +445,14 @@
     # TensorFlow Lite casts the scales slightly differently for uint8 and int8
     if not rescale_for_faf:
         if ifm_dtype == DataType.uint8:
-            scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
+            # For some cases of the Mean operator, the scale must be calculated differently to match the reference
+            if first_consumer_op.low_precision_scaling:
+                scales = [
+                    np.double(np.single(ifm_scale) / (np.single(weight_scale) * np.single(ofm_scale)))
+                    for weight_scale in weight_scales
+                ]
+            else:
+                scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
         elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
             scales = [
                 (np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)