MLBEDSW-7653: Extend Mean support for depth axis

If either the H or W axis has shape 1, the IFM can be reshaped to
support reduction over the depth axis.
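
For example, with W == 1 an NxHx1xC input can be viewed as NxHxCx1,
turning the depth reduction into a reduction over the width axis,
which is already supported. A minimal sketch of the equivalence
(NumPy is used purely for illustration and is not part of the patch):

    import numpy as np

    # Toy example: mean over the depth axis of an NxHx1xC tensor.
    ifm = np.random.rand(1, 8, 1, 16).astype(np.float32)

    # Direct reduction over the depth axis (axis 3).
    ref = ifm.mean(axis=3, keepdims=True)

    # Same data viewed as NxHxCx1, reduced over the width axis (axis 2).
    alt = ifm.reshape(1, 8, 16, 1).mean(axis=2, keepdims=True)

    assert np.allclose(ref, alt)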

Signed-off-by: Alexander Hansson <Alexander.Hansson@arm.com>
Change-Id: I432ff1c399b7cee4ca5f0a8f4461e9c0a936d804
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index d642fc5..7b46b8b 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -19,7 +19,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `3.8.1.dev14+ge59d5ed1.d20230707`
+Vela version: `3.8.1.dev14+gefc7d21e.d20230707`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -257,12 +257,13 @@
         When IFM tensor is 3D or 4D:  
           - Reduction in Batch axis is only supported if batch size is 1.  
           - Reduction in both Height and Width axes is supported.  
-          - Reduction in Depth axis is only supported if depth is 1.
+          - Reduction in Depth axis is supported if at least one of H,W,C is of size 1.
 - Product of reduced axes must be no greater than:  
-        - 16777216 for signed 8-bit inputs  
-        - 8388608 for unsigned 8-bit inputs  
-        - 65536 for signed 16-bit inputs
+        - 16777216 for signed 8-bit inputs.  
+        - 8388608 for unsigned 8-bit inputs.  
+        - 65536 for signed 16-bit inputs.
 - If Width axis is reduced its shape must be no greater than 4096.
+- If Depth axis is reduced its shape must be no greater than 4096.
 
 ### TFLite MINIMUM Constraints
 
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index a12eeb3..31d3ae1 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2004,16 +2004,9 @@
                     intermediate_shape.insert(i, 1)
 
         # Reshape to 4D
-        if dims == 2:
-            # Reshape WxC -> 1xHxWx1 to support both axes
-            reduce_axis = [False] + reduce_axis + [False]
-            ifm_shape = [1] + ifm_shape + [1]
-            intermediate_shape = [1] + intermediate_shape + [1]
-        elif dims == 3:
-            # Reshape to 4D HxWxC -> 1xHxWxC
-            reduce_axis = [False] + reduce_axis
-            ifm_shape = [1] + ifm_shape
-            intermediate_shape = [1] + intermediate_shape
+        reduce_axis = full_shape(4, reduce_axis, False)
+        ifm_shape = full_shape(4, ifm_shape, 1)
+        intermediate_shape = full_shape(4, intermediate_shape, 1)
 
         # If all dimensions to reduce have shape 1, the operation is essentially a memcpy.
         # We can then remove the whole op by propagating ofm to previous ops
@@ -2022,9 +2015,25 @@
             op = bypass_memory_only_ops(op, arch, nng)
             return op
 
-        # Compute kernel sizes for our convolutions.
-        # batch and depth axes are only supported if their shapes are 1.
-        # hence reduction in batch or depth axis is implicit.
+        # Support mean over the depth axis by shifting the C axis into the W position
+        # From the semantic checks we can assume that one of H,W,C has shape 1
+        if reduce_axis[3] and ifm_shape[3] > 1:
+            assert 1 in ifm_shape[1:], "Mean reduction over depth channel, but none of H,W,C has shape 1"
+            # If W=1 reshape NxHx1xC -> NxHxCx1, else reshape Nx1xWxC -> NxWxCx1
+            idx_to_del = 2 if ifm_shape[2] == 1 else 1
+
+            # Delete axis with size 1
+            del reduce_axis[idx_to_del]
+            del ifm_shape[idx_to_del]
+            del intermediate_shape[idx_to_del]
+
+            # Add another element to set channel-axis to one
+            reduce_axis.append(False)
+            ifm_shape.append(1)
+            intermediate_shape.append(1)
+
+        # Compute kernel sizes for our convolutions.
+        # Reduction over the batch axis is implicit, as it is only supported when the batch size is 1.
         h = ifm_shape[1] if reduce_axis[1] else 1
         w = ifm_shape[2] if reduce_axis[2] else 1
 
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 56dce14..3ac78b2 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -702,7 +702,7 @@
         When IFM tensor is 3D or 4D:
           - Reduction in Batch axis is only supported if batch size is 1.
           - Reduction in both Height and Width axes is supported.
-          - Reduction in Depth axis is only supported if depth is 1."""
+          - Reduction in Depth axis is supported if at least one of H,W,C is of size 1."""
         input_shape = op.inputs[0].shape
         dims = len(input_shape)
         if op.inputs[1].shape == []:
@@ -714,14 +714,22 @@
         for ax in axis:
             if ax < 0 or ax >= dims:
                 return False, "Axis parameter is out of bounds. axis: {axis}, dims: {dims}. "
-            elif dims == 3:
-                # depth is only supported if size is 1
-                if ax == 2 and input_shape[ax] != 1:
+
+            # Batch reduction is only supported if the batch size is 1
+            if dims == 4 and ax == 0:
+                if input_shape[0] != 1:
                     valid = False
                     break
-            else:  # 4D
-                # batch and depth are only supported if sizes are 1
-                if ax in [0, 3] and input_shape[ax] != 1:
+
+            # Depth reduction (3D input) is supported if any of H,W,C == 1
+            if dims == 3:
+                if ax == 2 and not any([s == 1 for s in input_shape]):
+                    valid = False
+                    break
+
+            # Depth reduction (4D input) is supported if any of H,W,C == 1
+            if dims == 4:
+                if ax == 3 and not any([s == 1 for s in input_shape[1:]]):
                     valid = False
                     break
 
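The dimension checks added above can be read as one predicate over the
input shape and the reduced axes. The standalone function below is a
hypothetical condensation for illustration only, not code from this
patch:

    def mean_reduction_supported(input_shape, axis):
        dims = len(input_shape)
        for ax in axis:
            # Batch reduction needs a batch size of 1.
            if dims == 4 and ax == 0 and input_shape[0] != 1:
                return False
            # Depth reduction needs one of H,W,C to have size 1.
            if dims == 3 and ax == 2 and not any(s == 1 for s in input_shape):
                return False
            if dims == 4 and ax == 3 and not any(s == 1 for s in input_shape[1:]):
                return False
        return True

    assert mean_reduction_supported([1, 16, 1, 64], [3])       # W == 1
    assert not mean_reduction_supported([1, 16, 16, 64], [3])  # none of H,W,C is 1
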
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 597e0a2..7d54400 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -191,7 +191,7 @@
     filter_range = (1, 8)
     filter_height_range = (1, 256)
     filter_product_range = (1, 256 * 256)
-    mean_width_size = 64 * 64
+    mean_reduced_axis_max_size = 64 * 64
     mean_kernel_product_int8 = 2 ** (24)
     mean_kernel_product_uint8 = 2 ** (23)
     mean_kernel_product_int16 = 2 ** (16)
@@ -315,6 +315,7 @@
         # Mean specific checks:
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_width)
+        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_depth)
 
         # Reshape specific checks:
         self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
@@ -844,9 +845,9 @@
     @docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16])
     def constraint_mean_height_width_product(cls, op):
         """Product of reduced axes must be no greater than:
-        - {} for signed 8-bit inputs
-        - {} for unsigned 8-bit inputs
-        - {} for signed 16-bit inputs"""
+        - {} for signed 8-bit inputs.
+        - {} for unsigned 8-bit inputs.
+        - {} for signed 16-bit inputs."""
         shape = op.inputs[0].shape
         if op.inputs[1].shape == []:
             axis = [int(op.inputs[1].values)]
@@ -869,15 +870,35 @@
         return prod <= max_prod, f"Datatype is {datatype}, product of axes is {prod}"
 
     @classmethod
-    @docstring_format_args([mean_width_size])
+    @docstring_format_args([mean_reduced_axis_max_size])
     def constraint_mean_width(cls, op):
         """If Width axis is reduced its shape must be no greater than {}."""
         shape = op.inputs[0].shape
         hi = 0 if len(shape) < 4 else 1
         h, w = shape[hi : hi + 2]
-        max_width = cls.mean_width_size
+        max_width = cls.mean_reduced_axis_max_size
         return w <= max_width, f"Width is {w}"
 
+    @classmethod
+    @docstring_format_args([mean_reduced_axis_max_size])
+    def constraint_mean_depth(cls, op):
+        """If Depth axis is reduced its shape must be no greater than {}."""
+        max_depth = cls.mean_reduced_axis_max_size
+        shape = op.inputs[0].shape
+
+        if op.inputs[1].shape == []:
+            axis = [int(op.inputs[1].values)]
+        else:
+            axis = list(op.inputs[1].values)
+
+        depth_idx = len(shape) - 1
+
+        supported = True
+        if depth_idx in axis and shape[-1] > max_depth:
+            supported = False
+
+        return supported, f"Depth is {shape[-1]}, shape is {shape}, axis is {axis}"
+
     @staticmethod
     def constraint_reshape_shape_constant(op):
         "Shape must be constant"