MLBEDSW-4501: Support MEAN single axis variation

When a MEAN operator with a single reduction axis
specifies the axis index attribute as an array with
a single element rather than a scalar index, the
operator is placed on the CPU even though it is
technically supported.
This commit fixes the issue and adds new
tests for the axis constraints.

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: Ia287f3b9cc80a805e972cd4b2962e52526a8dc16
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 642f134..7c60368 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1472,8 +1472,8 @@
         dims = len(shape)
 
         # Height and width axes have different index depending on dimensions
-        if axis.shape == []:  # single axis
-            axis = int(axis.values)
+        if len(axis.shape) <= 1:  # single axis
+            axis = int(axis.values) if len(axis.shape) == 0 else axis.values[0]
             if dims in (2, 3):
                 if axis == 0:
                     h, w = shape[axis], 1
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 5bf2c45..dfa2719 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -1040,11 +1040,11 @@
     def constraint_mean_axis(op):
         "Axis indices must correspond to height and width axes"
         dims = len(op.inputs[0].shape)
-        axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+        axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
         if dims == 2 or dims == 3:
-            valid = axis in (0, 1, [0, 1], [1, 0])
+            valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
         elif dims == 4:
-            valid = axis in (1, 2, [1, 2], [2, 1])
+            valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
         return valid, f"Axis is {axis}"
 
     @classmethod
@@ -1082,7 +1082,7 @@
         keep_dims is set to True and
         IFM datatype is int8"""
         shape = op.ifm.shape
-        axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+        axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
         # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
         # and constraint_mean_height_width_product
         if (
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 355b472..666a5ec 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -840,12 +840,15 @@
     assert support.is_operator_supported(op)
 
 
-def create_mean(input_shape, output_shape, indices, datatype, attrs):
+def create_mean(input_shape, output_shape, axis, datatype, attrs):
     ifm = Tensor(input_shape, datatype, "in")
     ifm.quantization = testutil.default_quant_params()
-    indices = create_const_tensor("indices", [len(indices)], DataType.int32, indices, np.uint8)
     ofm = Tensor(output_shape, datatype, "out")
     ofm.quantization = testutil.default_quant_params()
+    if type(axis) is list:
+        indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+    elif type(axis) is int:
+        indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
     op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
     return op
 
@@ -859,8 +862,22 @@
 
 
 def test_mean_axis():
-    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
     assert not support.is_operator_supported(op)
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
+    assert not support.is_operator_supported(op)
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
+    assert not support.is_operator_supported(op)
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
+    assert not support.is_operator_supported(op)
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+    assert support.is_operator_supported(op)
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
+    assert support.is_operator_supported(op)
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
+    assert support.is_operator_supported(op)
+    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
+    assert support.is_operator_supported(op)
 
 
 def test_mean_hw_product():