vela: Move common functionality
There is a recurring pattern of setting the three different shapes in a
tensor (shape, storage_shape and bandwidth_shape) to a single shape
value.
This adds a new function, set_all_shapes(), to the tensor class that
does this in one call.
Existing instances of manually setting all three shapes have been
changed to use this new function.
Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: Ibc74e741ea47cec473e6be42cc102f721ec63b11
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index cb0cc64..23ddf83 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -205,7 +205,7 @@
reshape_op.inputs = [inp, new_shape_tens]
reshape_op.attrs["new_shape"] = desired_shape
reshape_out = inp.clone("_reshaped")
- reshape_out.shape = reshape_out.storage_shape = reshape_out.bandwidth_shape = desired_shape
+ reshape_out.set_all_shapes(desired_shape)
reshape_out.ops = [reshape_op]
reshape_op.outputs = [reshape_out]
@@ -235,7 +235,7 @@
reshape_op.inputs = [inp, new_shape_tens]
reshape_op.attrs["new_shape"] = desired_shape
reshape_out = inp.clone("_reshaped")
- reshape_out.shape = reshape_out.storage_shape = reshape_out.bandwidth_shape = desired_shape
+ reshape_out.set_all_shapes(desired_shape)
reshape_out.ops = [reshape_op]
reshape_op.outputs = [reshape_out]
@@ -308,7 +308,7 @@
reshape_op = Operation("Reshape", reshape_name)
reshape_op.outputs = [out_tens]
reshape_in = out_tens.clone("_reshaped")
- reshape_in.shape = reshape_in.storage_shape = reshape_in.bandwidth_shape = reshape_input_shape
+ reshape_in.set_all_shapes(reshape_input_shape)
reshape_in.ops = [op]
out_tens.ops = [reshape_op]
reshape_op.inputs = [reshape_in, new_shape_tens]
@@ -425,9 +425,7 @@
del op.attrs["depth_multiplier"]
weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
- weight_tensor.shape = weight_tensor.storage_shape = weight_tensor.bandwidth_shape = list(
- weight_tensor.quant_values.shape
- )
+ weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
else:
raise UnsupportedFeatureError(
"Unsupported DepthwiseConv2d with depth_multiplier = {}, ifm channels = {}, ofm channels = {}".format(
@@ -441,9 +439,7 @@
if "DepthwiseConv2d" in op.type:
weight_tensor = op.inputs[1]
weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
- weight_tensor.shape = weight_tensor.storage_shape = weight_tensor.bandwidth_shape = list(
- weight_tensor.quant_values.shape
- )
+ weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
weight_tensor.weight_transpose_depthwise = True
return op