Vela: Fix diff in mean op
- Extend ifm/ofm dimensions explicitly in mean op
This fix a bug when ifm/ofm shape has different dimensions
e.g. IFM=1x19x18x25 axis=2 OFM=1x19x25,
the ofm_shape should be 1x19x1x25, not 1x1x19x25
- Fix wrong weight shape
Change-Id: I269eb71ea56c09deee2aa6c6433d9b2baa98a113
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 97e30ad..3815eed 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1172,7 +1172,9 @@
keep_dims = op.attrs.get("keep_dims", False)
inp, axis = op.inputs
shape = inp.shape
+ ofm_shape = op.ofm.shape
dims = len(shape)
+ dims_ofm = len(ofm_shape)
# Height and width axes have different index depending on dimensions
if axis.shape == [] or axis.shape[0] == 1: # single axis
@@ -1301,10 +1303,25 @@
op.forced_input_quantization = fiq
# Change dimensions to 4
- if dims < 4:
- shape = [1] + shape
- if dims == 2:
- shape += [1]
+ def extend_dims(dim, in_shape):
+ if dim < 4:
+ in_shape = [1] + in_shape
+ if dim == 2:
+ in_shape += [1]
+ return in_shape
+
+ if dims < 4 or dims_ofm < 4:
+ # Fix the ofm dimension when keep_dims is false
+ # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
+ if isinstance(axis, int) and dims_ofm + 1 == dims:
+ ofm_shape.insert(axis, 1)
+ elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
+ for i in axis:
+ ofm_shape.insert(i, 1)
+ shape = extend_dims(dims, shape)
+ dims_ofm = len(ofm_shape)
+ ofm_shape = extend_dims(dims_ofm, ofm_shape)
+ op.set_ifm_ofm_shapes()
# If height is greater than max kernel height, reshape from HxW to 1x(HxW)
if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
@@ -1325,7 +1342,8 @@
weight_quant.zero_point = 0
# Set weight shape to [H,W,C,B]
- weight_shape = shape[1:4] + [shape[0]]
+ weight_shape = [h, w, shape[3], shape[0]]
+
# Add unit weight tensor
op.set_input_tensor(
create_const_tensor(