MLBEDSW-7449: Add function description and type annotations
Add function descriptions and type annotations to the optimization
functions that are missing them.
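Fix a type annotation issue caused by re-assigning a variable
to a value of a different type.
Fix the UnsupportedFeatureError message in convert_depthwise_to_conv,
which was split into two separate arguments by stray trailing commas.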
Change-Id: I1ee442ff7a29cc07708fdd013430131eff599dd5
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 220ba1a..c0099ff 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -26,6 +26,7 @@
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .operation import Op
+from .operation import Operation
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import Tensor
@@ -192,8 +193,8 @@
return max(filter_size - (input_size % stride), 0)
-# Set input/output tensor equivalence to the same id for memory operations
-def set_tensor_equivalence(op, arch, nng):
+def set_tensor_equivalence(op: Operation, arch, nng) -> Operation:
+ """Set input/output tensor equivalence to the same id for memory operations."""
if op.type in memory_only_ops:
eid = op.outputs[0].equivalence_id
for inp in op.inputs:
@@ -300,16 +301,16 @@
return op
-def convert_depthwise_to_conv(op, arch, nng):
- # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
- # the ofm depth equals the depth multipler.
- # If those conditions are true, then we can perform a simple
- # switch of the operator type (and weight order)
-
+def convert_depthwise_to_conv(op: Operation, arch, nng) -> Operation:
+ """Convert DepthwiseConv2DBias to Conv2D to allow support for DepthwiseConv2DBias ops with 'depth multiplier' > 1,
+ as long as IFM depth = 1 and OFM depth is equal to the depth multiplier.
+ """
if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
ifm_shape = op.ifm_shapes[0]
weight_tensor = op.inputs[1]
ofm_shape = op.ofm_shapes[0]
+ # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
+ # the ofm depth equals the depth multiplier.
if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
# Change op type to Conv2d
op.type = Op.Conv2DBias
@@ -321,8 +322,8 @@
DebugDatabase.add_optimised(op, op)
else:
raise UnsupportedFeatureError(
- f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
- f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
+ f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},"
+ f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}"
)
return op