MLBEDSW-7652: Add mean support for batch and channel when shape is 1

- Add support for reducing batch and depth axes when their shape is 1 (see the sketch below)
- Refactor reshaping in convert_mean_to_depthwise_conv
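
The effect of the new support can be illustrated with a small standalone sketch
(not part of this patch): a mean over axes whose extent is 1 does not change any
values, which is why such reductions can be accepted and, when only such axes
are reduced, turned into a memcpy by the graph optimiser.

    # Standalone illustration, not Vela code: reducing size-1 axes is a copy.
    import numpy as np

    ifm = np.arange(36, dtype=np.int8).reshape(1, 6, 6, 1)

    # Reduce the batch (0) and depth (3) axes; both have extent 1, so the
    # mean of each single element is the element itself.
    ofm = ifm.mean(axis=(0, 3), keepdims=True)

    assert ofm.shape == (1, 6, 6, 1)
    assert np.array_equal(ofm, ifm)  # numerically a plain copy (memcpy)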

Signed-off-by: Alexander Hansson <Alexander.Hansson@arm.com>
Change-Id: If663395934ab58c76ba92b6ebaaf484a389ae699
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index c1c58d3..d642fc5 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -19,7 +19,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `3.8.1.dev9+g85b7790.d20230616`
+Vela version: `3.8.1.dev14+ge59d5ed1.d20230707`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -251,12 +251,18 @@
 This is a list of constraints that the MEAN operator must satisfy in order to be scheduled on the NPU.
 
 - Input tensor must be at least 2D
-- Axis indices must correspond to height and width axes
-- Product of height and width must be no greater than:  
+- Requirements for axis parameter:  
+        When IFM tensor is 2D:  
+          - Reduction in both axes is supported.  
+        When IFM tensor is 3D or 4D:  
+          - Reduction in Batch axis is only supported if batch size is 1.  
+          - Reduction in both Height and Width axes is supported.  
+          - Reduction in Depth axis is only supported if depth is 1.
+- Product of reduced axes must be no greater than:  
         - 16777216 for signed 8-bit inputs  
         - 8388608 for unsigned 8-bit inputs  
         - 65536 for signed 16-bit inputs
-- Width must be no greater than 4096
+- If Width axis is reduced, its shape must be no greater than 4096.
 
 ### TFLite MINIMUM Constraints
 
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index ebfdbf3..7a82d2c 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -506,14 +506,21 @@
 
 
 def test_mean_axis():
-    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
-    assert not semantic_checker.is_operator_semantic_valid(op)
     op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
     assert not semantic_checker.is_operator_semantic_valid(op)
-    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
+    op = create_mean([1, 6, 6, 1], [1, 1, 1, 1], [3], DataType.int8, {"keep_dims": True})
+    assert semantic_checker.is_operator_semantic_valid(op)
+
+    op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0], DataType.int8, {"keep_dims": True})
+    assert not semantic_checker.is_operator_semantic_valid(op)
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
+    assert semantic_checker.is_operator_semantic_valid(op)
+
+    op = create_mean([2, 6, 6, 16], [2, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
     assert not semantic_checker.is_operator_semantic_valid(op)
     op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
-    assert not semantic_checker.is_operator_semantic_valid(op)
+    assert semantic_checker.is_operator_semantic_valid(op)
+
     op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
     assert semantic_checker.is_operator_semantic_valid(op)
     op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 28dead1..a12eeb3 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1982,58 +1982,59 @@
         max_kernel_size = 4096
         max_height = 64
         inp, axis = op.inputs
-        shape = inp.shape
-        ofm_shape = op.ofm.shape
-        dims = len(shape)
-        dims_ofm = len(ofm_shape)
+        dims = len(inp.shape)
+        dims_ofm = len(op.ofm.shape)
         ofmq = op.ofm.quantization
         ifmq = op.ifm.quantization
 
-        # Height and width axes have different index depending on dimensions
-        if axis.shape == [] or axis.shape[0] == 1:  # single axis
-            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
-            # If dims is 4, axis 1 refers to h-dimension
-            if dims == 4:
-                reduce_h, reduce_w = (True, False) if axis == 1 else (False, True)
-            else:
-                reduce_h, reduce_w = (True, False) if axis == 0 else (False, True)
-        else:  # multiple axes
-            axis = sorted(axis.values)
-            reduce_h, reduce_w = (True, True)
+        # reduce_axis[i] is true if axis i should be reduced
+        if axis.shape == []:
+            reduce_axis = [True if i == axis.values else False for i in range(dims)]
+        else:
+            reduce_axis = [True if i in axis.values else False for i in range(dims)]
 
-        # Change dimensions to 4
-        def extend_dims(dim, in_shape):
-            if dim < 4:
-                in_shape = [1] + in_shape
-                if dim == 2:
-                    in_shape += [1]
-            return in_shape
+        ifm_shape = inp.shape.copy()
+        intermediate_shape = op.ofm.shape.copy()
 
-        if dims < 4 or dims_ofm < 4:
-            # Fix the ofm dimension when keep_dims is false
-            # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
-            if isinstance(axis, int) and dims_ofm + 1 == dims:
-                ofm_shape.insert(axis, 1)
-            elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
-                for i in axis:
-                    ofm_shape.insert(i, 1)
-            shape = extend_dims(dims, shape)
-            dims_ofm = len(ofm_shape)
-            ofm_shape = extend_dims(dims_ofm, ofm_shape)
-            op.set_ifm_ofm_shapes()
+        # Fix intermediate_shape when keep_dims is false
+        # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the intermediate_shape should be 1xHx1xC
+        if dims_ofm < dims:
+            for i in range(dims):
+                if reduce_axis[i]:
+                    intermediate_shape.insert(i, 1)
 
-        # Compute kernel sizes for our convolutions
-        h = shape[1] if reduce_h else 1
-        w = shape[2] if reduce_w else 1
+        # Reshape to 4D
+        if dims == 2:
+            # Reshape WxC -> 1xWxCx1 to support reduction in both axes
+            reduce_axis = [False] + reduce_axis + [False]
+            ifm_shape = [1] + ifm_shape + [1]
+            intermediate_shape = [1] + intermediate_shape + [1]
+        elif dims == 3:
+            # Reshape to 4D HxWxC -> 1xHxWxC
+            reduce_axis = [False] + reduce_axis
+            ifm_shape = [1] + ifm_shape
+            intermediate_shape = [1] + intermediate_shape
+
+        # If all dimensions to reduce have shape 1, the operation is essentially a memcpy.
+        # We can then remove the whole op by propagating ofm to previous ops
+        if not any([reduce_axis[i] and ifm_shape[i] > 1 for i in range(4)]):
+            op.type = Op.Memcpy
+            op = bypass_memory_only_ops(op, arch, nng)
+            return op
+
+        # Compute kernel sizes for our convolutions.
+        # Batch and depth axes are only supported if their shapes are 1, so
+        # reduction over those axes is implicit and does not affect the kernel.
+        h = ifm_shape[1] if reduce_axis[1] else 1
+        w = ifm_shape[2] if reduce_axis[2] else 1
+
         num_elements_in_axis = h * w
 
         # If one convolution is enough, but height is greater than max kernel height
         # reshape from HxW to 1x(HxW)
         # This can only be done if the mean is computed over both H and W
-        if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_h and reduce_w:
-            shape = [shape[0], 1, h * w, shape[3]]
-            op.ifm_shapes[0] = Shape4D(shape)
-            op.ifm.shape = shape
+        if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_axis[1] and reduce_axis[2]:
+            ifm_shape = [ifm_shape[0], 1, h * w, ifm_shape[3]]
             w = h * w
             h = 1
 
@@ -2065,10 +2066,11 @@
                 }
             )
 
-            b, _, _, c = shape
+            b, _, _, c = ifm_shape
 
             intermediate_tensor = op.ofm.clone(suffix=f"_conv_sum_{i}", set_unique=True)
             intermediate_tensor.dtype = DataType.int32
+            intermediate_tensor.shape = intermediate_shape
             intermediate_op.set_output_tensor(intermediate_tensor)
 
             # as we have several convs, scaling/rounding must be done after the sum has been calculated
@@ -2081,11 +2083,11 @@
                 weight_h = height_per_conv
 
             # compute ifm read offset and shape for the convolution
-            read_shape_h = weight_h if reduce_h else shape[1]
-            read_shape_w = w if reduce_w else shape[2]
+            read_shape_h = weight_h if reduce_axis[1] else ifm_shape[1]
+            read_shape_w = w if reduce_axis[2] else ifm_shape[2]
 
             intermediate_op.read_offsets[0] = Shape4D([0, i * height_per_conv, 0, 0])
-            intermediate_op.read_shapes[0] = Shape4D(shape).with_hw(read_shape_h, read_shape_w)
+            intermediate_op.read_shapes[0] = Shape4D(ifm_shape).with_hw(read_shape_h, read_shape_w)
 
             weight_quant = QuantizationParameters(0, 255, scale_f32=1.0, zero_point=0)
             weight_shape = [weight_h, w, c, b]
@@ -2112,9 +2114,9 @@
             intermediate_op.inputs.append(bias)
             intermediate_op.set_ifm_ofm_shapes()
 
-            # We want to avoid reshaping the tensor directly, to not affect other ops
+            # We want to avoid reshaping the ifm tensor directly, to not affect other ops
             # so we update the shape explicitly for this operation
-            intermediate_op.ifm_shapes[0] = Shape4D(shape)
+            intermediate_op.ifm_shapes[0] = Shape4D(ifm_shape)
 
             convs.append(intermediate_op)
             DebugDatabase.add_optimised(op, intermediate_op)
@@ -2128,6 +2130,7 @@
             while len(convs):
                 intermediate_tensor = op.ofm.clone(suffix=f"_add_sum_{idx}", set_unique=True)
                 intermediate_tensor.dtype = DataType.int32
+                intermediate_tensor.shape = intermediate_shape
 
                 one_scale_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
 
@@ -2136,7 +2139,6 @@
                     ifm2 = convs.pop().ofm
                 else:
                     ifm2 = prev_add_op.ofm
-
                 intermediate_op = create_add(f"{op.name}_add_{idx}", ifm, ifm2, one_scale_quant)
                 intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
                 intermediate_op.set_output_tensor(intermediate_tensor)
@@ -2180,6 +2182,7 @@
         )
         op.set_input_tensor(scalar, 1)
         op.set_ifm_ofm_shapes()
+        op.ofm_shapes[0] = Shape4D(intermediate_shape)
 
         # Reference using TFL rounding for the multiply
         op.rounding_mode = RoundingMode.TFLite
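
A rough standalone sketch (not part of this patch) of the shape bookkeeping that the
rewritten convert_mean_to_depthwise_conv performs; plan_mean and its return convention
are hypothetical names used only for illustration:

    # Standalone sketch, not Vela code: build the reduce_axis mask, re-insert
    # the reduced axes into the OFM shape when keep_dims is false, then pad
    # everything out to 4D the same way as the optimiser above.
    def plan_mean(ifm_shape, ofm_shape, axis):
        dims = len(ifm_shape)
        reduce_axis = [i in axis for i in range(dims)]
        intermediate_shape = list(ofm_shape)

        # keep_dims=False drops the reduced axes from the OFM; put them back as 1s
        if len(ofm_shape) < dims:
            for i in range(dims):
                if reduce_axis[i]:
                    intermediate_shape.insert(i, 1)

        # Pad to 4D: 2D -> 1xWxCx1, 3D -> 1xHxWxC
        if dims == 2:
            reduce_axis = [False] + reduce_axis + [False]
            ifm_shape = [1] + list(ifm_shape) + [1]
            intermediate_shape = [1] + intermediate_shape + [1]
        elif dims == 3:
            reduce_axis = [False] + reduce_axis
            ifm_shape = [1] + list(ifm_shape)
            intermediate_shape = [1] + intermediate_shape

        return reduce_axis, ifm_shape, intermediate_shape

    # e.g. IFM=1xHxWxC, axis=[1, 2], keep_dims=False -> OFM=1xC
    print(plan_mean([1, 6, 6, 16], [1, 16], [1, 2]))
    # ([False, True, True, False], [1, 6, 6, 16], [1, 1, 1, 16])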
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 444c04a..56dce14 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -696,14 +696,36 @@
 
     @staticmethod
     def constraint_mean_axis(op):
-        "Axis indices must correspond to height and width axes"
-        dims = len(op.inputs[0].shape)
-        axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
-        if dims == 2 or dims == 3:
-            valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
-        elif dims == 4:
-            valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
-        return valid, f"Axis is {axis}"
+        """Requirements for axis parameter:
+        When IFM tensor is 2D:
+          - Reduction in both axes is supported.
+        When IFM tensor is 3D or 4D:
+          - Reduction in Batch axis is only supported if batch size is 1.
+          - Reduction in both Height and Width axes is supported.
+          - Reduction in Depth axis is only supported if depth is 1."""
+        input_shape = op.inputs[0].shape
+        dims = len(input_shape)
+        if op.inputs[1].shape == []:
+            axis = [int(op.inputs[1].values)]
+        else:
+            axis = list(op.inputs[1].values)
+        valid = True
+
+        for ax in axis:
+            if ax < 0 or ax >= dims:
+                return False, f"Axis parameter is out of bounds. axis: {axis}, dims: {dims}."
+            elif dims == 3:
+                # Depth is only supported if its size is 1
+                if ax == 2 and input_shape[ax] != 1:
+                    valid = False
+                    break
+            elif dims == 4:
+                # Batch and depth are only supported if their sizes are 1
+                if ax in [0, 3] and input_shape[ax] != 1:
+                    valid = False
+                    break
+
+        return valid, f"Shape is {input_shape}, Axis is {axis}."
 
     @staticmethod
     def constraint_matching_in_out_quant(op):
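
For reference, the accepted shape/axis combinations of the new constraint_mean_axis
can be restated in a few lines (a standalone restatement, not the Vela API; the cases
mirror test_mean_axis above):

    # Standalone restatement of the axis rule, not Vela code.
    def mean_axis_ok(shape, axis):
        dims = len(shape)
        for ax in axis:
            if ax < 0 or ax >= dims:
                return False           # out of bounds
            if dims == 3 and ax == 2 and shape[ax] != 1:
                return False           # depth must be 1
            if dims == 4 and ax in [0, 3] and shape[ax] != 1:
                return False           # batch and depth must be 1
        return True

    assert mean_axis_ok([1, 6, 6, 16], [0, 1])      # batch is 1 -> accepted
    assert not mean_axis_ok([2, 6, 6, 16], [0, 1])  # batch is 2 -> rejected
    assert mean_axis_ok([1, 6, 6, 1], [3])          # depth is 1 -> accepted
    assert not mean_axis_ok([1, 6, 6, 16], [3])     # depth is 16 -> rejected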
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 92a7f3c..597e0a2 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -843,13 +843,20 @@
     @classmethod
     @docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16])
     def constraint_mean_height_width_product(cls, op):
-        """Product of height and width must be no greater than:
+        """Product of reduced axes must be no greater than:
         - {} for signed 8-bit inputs
         - {} for unsigned 8-bit inputs
         - {} for signed 16-bit inputs"""
         shape = op.inputs[0].shape
-        hi = 0 if len(shape) < 4 else 1
-        h, w = shape[hi : hi + 2]
+        if op.inputs[1].shape == []:
+            axis = [int(op.inputs[1].values)]
+        else:
+            axis = list(op.inputs[1].values)
+
+        # Compute the product of the shapes of all reduced axes
+        axis_shapes = [shape[ax] for ax in axis]
+        prod = np.prod(axis_shapes)
+
         if op.ifm.dtype == DataType.int16:
             max_prod = cls.mean_kernel_product_int16
             datatype = "int16"
@@ -859,43 +866,18 @@
         else:
             max_prod = cls.mean_kernel_product_int8
             datatype = "int8"
-        return h * w <= max_prod, f"Datatype is {datatype}, product of height and width is {h * w}"
+        return prod <= max_prod, f"Datatype is {datatype}, product of axes is {prod}"
 
     @classmethod
     @docstring_format_args([mean_width_size])
     def constraint_mean_width(cls, op):
-        """Width must be no greater than {}"""
+        """If Width axis is reduced its shape must be no greater than {}."""
         shape = op.inputs[0].shape
         hi = 0 if len(shape) < 4 else 1
         h, w = shape[hi : hi + 2]
         max_width = cls.mean_width_size
         return w <= max_width, f"Width is {w}"
 
-    @classmethod
-    @docstring_format_args([dilated_height_range[1]])
-    def constraint_mean_height_single_axis(cls, op):
-        """For single axis averages across the height dimension:
-        IFM height must be no greater than {}"""
-        inp, axis = op.inputs
-        if axis.shape == [] or axis.shape[0] == 1:  # single axis
-            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
-        else:
-            # Multiple axes
-            return True, ""
-
-        shape = inp.shape
-        if len(shape) < 3:
-            # No height dimension present in IFM
-            return True, ""
-        if axis != len(shape) - 3:
-            # Not averaging across the height dimension
-            return True, ""
-
-        h = shape[axis]
-        ifm, ofm = op.get_ifm_ofm()
-
-        return h <= cls.dilated_height_range[1], f"Height is {h}"
-
     @staticmethod
     def constraint_reshape_shape_constant(op):
         "Shape must be constant"