MLBEDSW-7449: Add function description and type annotations

Add function descriptions and type annotations to the optimization
functions that are missing them.
Fix a type annotation issue caused by re-assigning a variable
to a value of a different type.
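
As a minimal, hypothetical sketch of the pattern (not the exact code
changed here): re-using one name for values of different types makes
mypy report an incompatible assignment, so the converted value is
bound to a new, explicitly annotated name instead:

    padding = "SAME"                        # str, as read from an op attribute
    # padding = padding == "SAME"           # mypy: expression has type "bool",
    #                                       #       variable has type "str"
    same_padding: bool = padding == "SAME"  # separate name, no re-assignment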

Change-Id: I1ee442ff7a29cc07708fdd013430131eff599dd5
Signed-off-by: Raul Farkas <raul.farkas@arm.com>
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 220ba1a..c0099ff 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -26,6 +26,7 @@
 from .errors import UnsupportedFeatureError
 from .errors import VelaError
 from .operation import Op
+from .operation import Operation
 from .operation_util import create_avgpool_nop
 from .shape4d import Shape4D
 from .tensor import Tensor
@@ -192,8 +193,8 @@
     return max(filter_size - (input_size % stride), 0)
 
 
-# Set input/output tensor equivalence to the same id for memory operations
-def set_tensor_equivalence(op, arch, nng):
+def set_tensor_equivalence(op: Operation, arch, nng) -> Operation:
+    """Set input/output tensor equivalence to the same id for memory operations."""
     if op.type in memory_only_ops:
         eid = op.outputs[0].equivalence_id
         for inp in op.inputs:
@@ -300,16 +301,16 @@
     return op
 
 
-def convert_depthwise_to_conv(op, arch, nng):
-    # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
-    # the ofm depth equals the depth multipler.
-    # If those conditions are true, then we can perform a simple
-    # switch of the operator type (and weight order)
-
+def convert_depthwise_to_conv(op: Operation, arch, nng) -> Operation:
+    """Convert DepthwiseConv2DBias to Conv2D to allow support for DepthwiseConv2DBias ops with 'depth multiplier' > 1,
+    as long as IFM depth = 1 and OFM depth is equal to the depth multiplier.
+    """
     if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
         ifm_shape = op.ifm_shapes[0]
         weight_tensor = op.inputs[1]
         ofm_shape = op.ofm_shapes[0]
+        # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
+        # the ofm depth equals the depth multiplier.
         if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
             # Change op type to Conv2d
             op.type = Op.Conv2DBias
@@ -321,8 +322,8 @@
             DebugDatabase.add_optimised(op, op)
         else:
             raise UnsupportedFeatureError(
-                f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
-                f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
+                f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},"
+                f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}"
             )
     return op
 
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 21c02f3..28dead1 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -264,7 +264,7 @@
     return padding, skirt
 
 
-def fixup_conv2d_backprop(op, arch, nng):
+def fixup_conv2d_backprop(op: Operation, arch, nng) -> Operation:
     if op.type == Op.Conv2DBackpropInput:
         # flip the inputs
         op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
@@ -455,7 +455,7 @@
     return op
 
 
-def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
+def convert_argmax_to_depthwise_conv_and_max_pool(op: Operation, arch, nng) -> Operation:
     """
     Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.
 
@@ -762,7 +762,8 @@
     return op
 
 
-def fixup_resize(op, arch, nng):
+def fixup_resize(op: Operation, arch, nng) -> Operation:
+    """Fixup resize ops to increase support for ResizeNearestNeighbor cases."""
     if op.type.is_resize_op() and op.run_on_npu:
         if op.ifm_shapes[0] == op.ofm_shapes[0]:
             # Bypass the resize op which is essentially a NOP
@@ -787,7 +788,8 @@
     return op
 
 
-def rewrite_fully_connected_input(op: Operation, arch, nng):
+def rewrite_fully_connected_input(op: Operation, arch, nng) -> Operation:
+    """Rewrite FullyConnected shape as 2D to allow it to run on NPU."""
     # If the operation already have a read shape do not modify
     # the ifm shape, since that will already be correct
     if op.type == Op.FullyConnected and not op.read_shapes[0]:
@@ -803,7 +805,8 @@
     return op
 
 
-def convert_batched_fc_shape(op, arch, nng):
+def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
+    """Convert batched FullyConnected op shape to allow for support on NPU."""
     if op.type == Op.FullyConnected:
         # Check if the first dimension indicates batching
         if op.ifm_shapes[0].batch > 1:
@@ -940,7 +943,7 @@
     return op
 
 
-def reorder_depthwise_weights(op, arch, nng):
+def reorder_depthwise_weights(op: Operation, arch, nng) -> Operation:
     if op.type.is_depthwise_conv2d_op():
         weight_tensor = op.inputs[1]
         if not weight_tensor.weight_transpose_depthwise:
@@ -1159,10 +1162,11 @@
     return op
 
 
-def convert_conv_to_fc(op, arch, nng):
+def convert_conv_to_fc(op: Operation, arch, nng) -> Operation:
+    """Convert 1x1 Conv2D that behave like FullyConnected to FullyConnected, since they don't need any weight
+    buffering.
+    """
     # Conv 1x1 can be equivalent to Fully Connected.
-    # By representing certain convs as fully connected layers, Vela can better determine wether or not to use
-    # caching/double buffering for the weights.
     # (Weights dont need to be reloaded for convs when IFM H and W are 1)
     if op.type == Op.Conv2DBias:
         h = op.ifm_shapes[0].height
@@ -1184,7 +1188,8 @@
     return op
 
 
-def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
+def fixup_relus_with_differing_ifm_ofm_scaling(op: Operation, arch, nng) -> Operation:
+    """Fixup Relu with different IFM and OFM to allow fusing by adding its own primary op."""
     if op.run_on_npu and op.type.is_relu_op():
         ifm = op.inputs[0]
         ofm = op.outputs[0]
@@ -1209,21 +1214,24 @@
     return op
 
 
-def convert_lstm(op, arch, nng):
+def convert_lstm(op: Operation, arch, nng) -> Operation:
+    """Convert LSTM op into its basic opearations to allow for support on NPU."""
     if op.type == Op.UnidirectionalSequenceLstm:
         lstm = Lstm(op)
         op = lstm.get_graph()
     return op
 
 
-def convert_softmax(op, arch, nng):
+def convert_softmax(op: Operation, arch, nng) -> Operation:
+    """Convert Softmax op into its basic operations to allow for support on NPU."""
     if op.type == Op.Softmax and op.run_on_npu:
         softmax = SoftMax(op)
         op = softmax.get_graph()
     return op
 
 
-def convert_prelu(op, arch, nng):
+def convert_prelu(op: Operation, arch, nng) -> Operation:
+    """Convert PReLU op to other ops based on alpha values to allow for support on NPU."""
     if op.type == Op.Prelu:
         ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
         if None in (ifm, alpha, ofm):
@@ -1340,7 +1348,7 @@
     return op
 
 
-def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
+def convert_mul_max_to_abs_or_lrelu(op: Operation, arch, nng) -> Operation:
     r"""Whenever there is a subgraph with this topology:
 
     Input    X   For X = -1 or X > 0
@@ -1432,7 +1440,8 @@
     return op
 
 
-def convert_hardswish_to_lut(op, arch, nng):
+def convert_hardswish_to_lut(op: Operation, arch, nng) -> Operation:
+    """Convert HardSwish to LUT to allow for support on NPU."""
     if op.type == Op.HardSwish:
         ifm, ofm = op.get_ifm_ofm()
         # Generate the LUT
@@ -1645,8 +1654,8 @@
     return convert_to_lut(op, values, "lrelu")
 
 
-def convert_lrelu(op, arch, nng):
-    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
+def convert_lrelu(op: Operation, arch, nng) -> Operation:
+    """Convert LeakyRelu to a LUT based solution if possible, otherwise a mul + max."""
     if op.type != Op.LeakyRelu:
         return op
     ifm, ofm = op.get_ifm_ofm()
@@ -1667,8 +1676,8 @@
     return convert_lrelu_to_mul_max(op, arch)
 
 
-def convert_tanh_sigmoid_to_lut(op, arch, nng):
-    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
+def convert_tanh_sigmoid_to_lut(op: Operation, arch, nng) -> Operation:
+    """Convert int8/uint8 Sigmoid and Tanh to a LUT based solution."""
     if op.type == Op.Sigmoid:
         return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
     elif op.type == Op.Tanh:
@@ -1717,7 +1726,7 @@
     return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
 
 
-def replace_pad_by_hw_pad(op: Operation, arch, nng):
+def replace_pad_by_hw_pad(op: Operation, arch, nng) -> Operation:
     """
     Tries to completely remove a PAD operator by using hardware padding.
     E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
@@ -1889,7 +1898,8 @@
     return avgpool_op
 
 
-def fixup_bias_tensors(op, arch, nng, dtype=None):
+def fixup_bias_tensors(op: Operation, arch, nng, dtype=None) -> Operation:
+    """Fixup ops that require a bias and don't have one by adding a bias tensor filled with zeros."""
     if op.type.needs_bias() and op.bias is None:
         # Op has no bias, add bias tensor filled with zeros
         nr_biases = op.inputs[1].shape[-1]
@@ -1924,7 +1934,7 @@
     return False
 
 
-def fixup_asymmetric_weights(op, arch, nng):
+def fixup_asymmetric_weights(op: Operation, arch, nng) -> Operation:
     if detect_asymmetric_weights(op):
         if op.run_on_npu:
             print("Zero points have been adjusted.")
@@ -2180,7 +2190,8 @@
     return op
 
 
-def convert_ops_to_lut(op, arch, nng):
+def convert_ops_to_lut(op: Operation, arch, nng) -> Operation:
+    """Convert Exp to 8bit or 16bit LUT to allow for support on NPU."""
     if op.type == Op.Exp:
         if op.ifm.dtype == DataType.int8:
             return create_lut_8bit_op(op, math.exp, "exp")
@@ -2290,7 +2301,8 @@
     return op
 
 
-def fixup_dilation_gt2(op, arch, nng):
+def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation:
+    """Fixup Conv2DBias and DepthwiseConv2DBias to allow dilation greater than 2."""
     assert op.run_on_npu
     if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
         dilation_w, dilation_h = op.get_kernel_dilation()