Vela: Fix diff in mean op

  - Extend ifm/ofm dimensions explicitly in mean op
    This fix a bug when ifm/ofm shape has different dimensions
    e.g. IFM=1x19x18x25 axis=2 OFM=1x19x25,
         the ofm_shape should be 1x19x1x25, not 1x1x19x25
  - Fix wrong weight shape

Change-Id: I269eb71ea56c09deee2aa6c6433d9b2baa98a113
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 97e30ad..3815eed 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1172,7 +1172,9 @@
         keep_dims = op.attrs.get("keep_dims", False)
         inp, axis = op.inputs
         shape = inp.shape
+        ofm_shape = op.ofm.shape
         dims = len(shape)
+        dims_ofm = len(ofm_shape)
 
         # Height and width axes have different index depending on dimensions
         if axis.shape == [] or axis.shape[0] == 1:  # single axis
@@ -1301,10 +1303,25 @@
             op.forced_input_quantization = fiq
 
         # Change dimensions to 4
-        if dims < 4:
-            shape = [1] + shape
-            if dims == 2:
-                shape += [1]
+        def extend_dims(dim, in_shape):
+            if dim < 4:
+                in_shape = [1] + in_shape
+                if dim == 2:
+                    in_shape += [1]
+            return in_shape
+
+        if dims < 4 or dims_ofm < 4:
+            # Fix the ofm dimension when keep_dims is false
+            # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
+            if isinstance(axis, int) and dims_ofm + 1 == dims:
+                ofm_shape.insert(axis, 1)
+            elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
+                for i in axis:
+                    ofm_shape.insert(i, 1)
+            shape = extend_dims(dims, shape)
+            dims_ofm = len(ofm_shape)
+            ofm_shape = extend_dims(dims_ofm, ofm_shape)
+            op.set_ifm_ofm_shapes()
 
         # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
         if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
@@ -1325,7 +1342,8 @@
         weight_quant.zero_point = 0
 
         # Set weight shape to [H,W,C,B]
-        weight_shape = shape[1:4] + [shape[0]]
+        weight_shape = [h, w, shape[3], shape[0]]
+
         # Add unit weight tensor
         op.set_input_tensor(
             create_const_tensor(