MLBEDSW-2688: LeakyRelu rewrite to LUT or MUL/MAX

Rewrites LeakyRelu operations to a LUT-based activation function when
possible, otherwise to a combination of Mul/Max operations.
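
For reference, the 8-bit LUT path reduces to one table entry per input
value; a minimal sketch for uint8, with alpha the LeakyRelu slope and
zp the OFM zero point (mirroring convert_lrelu_to_lut below):

    values = [x if x >= zp else int(round(zp - alpha * (zp - x)))
              for x in range(256)]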

Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Change-Id: I3d2eb2dba7145997c3cc711d0ef18ab355fbb416
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 78c0dcd..8d920d8 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 
+from . import lut
 from . import rewrite_graph
 from .data_type import DataType
 from .errors import UnsupportedFeatureError
@@ -585,6 +586,12 @@
         # make sure the Mul doesn't have a faf
         if mul.attrs["fused_activation_function"]:
             return op
+        ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
+            return op
+        if not ifm.is_scaling_equal(ofm):
+            # the rewrite to LeakyRelu currently only makes sense if IFM and OFM quantization are identical
+            return op
 
         # finds the branched input that goes to both the Max and the Mul
         shared = set(op.inputs) & set(mul.inputs)
@@ -599,6 +606,8 @@
             # check that it is a constant
             if const.type != "Const":
                 return op
+            # Remove the Mul from the shared input's consumers
+            shared_in.consumer_list.remove(mul)
         else:
             return op
 
@@ -618,6 +627,147 @@
     return op
 
 
+def convert_lrelu_to_mul_max(op, arch):
+    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
+    # (the opposite of convert_mul_max_to_abs_or_lrelu)
+    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+
+    # Add multiplication with alpha
+    mul_alpha = Operation("MulAct", op.name + "_mul_alpha")
+    mul_alpha.add_input_tensor(ifm)
+    # Create const tensor containing alpha as scalar
+    alpha = op.attrs["alpha"]
+    quantization = ifm.quantization.clone()
+    quantization.min = 0
+    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
+    quantization.scale_f32 = alpha
+    quantization.zero_point = 0
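+    # The scalar's quantized value is 1; with scale_f32 = alpha it represents
+    # the real value alpha, so this Mul effectively computes alpha * IFM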
+    alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [], ifm.dtype, [1], np.int8, quantization=quantization)
+    mul_alpha.add_input_tensor(alpha_tens)
+    fm_alpha = ofm.clone(op.name + "_alpha")
+    mul_alpha.set_output_tensor(fm_alpha)
+
+    if ifm.is_scaling_equal(ofm):
+        # No identity multiplication is needed
+        fm_id = ifm
+    else:
+        # Add multiplication with identity
+        mul_identity = Operation("MulAct", op.name + "_mul_identity")
+        mul_identity.add_input_tensor(ifm)
+        # Create const tensor containing identity as scalar
+        quantization = ifm.quantization.clone()
+        quantization.min = 0
+        quantization.max = quantization.quant_max - quantization.quant_min
+        quantization.scale_f32 = 1
+        quantization.zero_point = 0
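+        # Scalar value 1 at scale 1.0 represents an identity multiplication;
+        # this Mul only rescales the IFM into the OFM quantization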
+        identity_tens = create_const_tensor(
+            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
+        )
+        mul_identity.add_input_tensor(identity_tens)
+        fm_id = ofm.clone(op.name + "_id")
+        mul_identity.set_output_tensor(fm_id)
+
+    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
+    op.type = "Maximum"
+    op.name = op.name.replace("LeakyRelu", "Maximum")
+    op.inputs = []
+    ifm.consumer_list.remove(op)
+    op.add_input_tensor(fm_alpha)
+    op.add_input_tensor(fm_id)
+    return op
+
+
+def convert_lrelu_to_lut(op, arch):
+    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+    # Rewrite LeakyRelu as an Add with scalar 0 plus a LUT activation
+    op.type = "AddAct"
+    op.name = op.name + "_add"
+    op.attrs.update({"npu_block_type": NpuBlockType.ElementWise})
+    # Mark as no-op to enable potential fusing optimizations
+    op.attrs["is_nop"] = True
+    # Create an input tensor containing scalar zero
+    quantization = QuantizationParameters(0.0, 255.0)
+    quantization.scale_f32 = 1.0
+    quantization.zero_point = 0
+    tens = create_const_tensor(op.inputs[0].name + "_add", [], ifm.dtype, [0], np.uint8, quantization=quantization)
+    op.add_input_tensor(tens)
+    alpha = op.attrs["alpha"]
+    zp = ofm.quantization.zero_point
+    # Generate the LUT
+    if ifm.dtype.size_in_bytes() == 1:
+        dtype = DataType.int8
+        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
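+        # Identity at or above the zero point; below it, the distance to the
+        # zero point is scaled by alpha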
+        values = [int(x) if x >= zp else int(round(zp - alpha * (zp - x))) for x in ix]
+    else:
+        # int16
+        dtype = DataType.int32
+        values = []
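+        # 512 segments of 128 input values each, covering [-32768, 32640];
+        # each 32-bit entry packs a 16-bit slope (upper half) and a 16-bit
+        # base (lower half) for piecewise linear interpolation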
+        for ix in range(512):
+            x = (ix - 256) * 128
+            if x >= zp:
+                base = x
+                slope = 128
+            else:
+                base = int(round(zp - alpha * (zp - x)))
+                next_base = int(round(zp - alpha * (zp - (x + 127))))
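+                # rescale the rise from x to x + 127 to a full 128-wide
+                # segment, matching the identity slope of 128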
+                slope = int(round(128 * (next_base - base) / 127))
+            value = ((slope << 16) & 0xFFFF0000) + (base & 0xFFFF)
+            values.append(value)
+    lut_tensor = lut.create_lut_tensor(op.name + "_lut", values, dtype)
+    op.set_activation_lut(lut_tensor)
+    return op
+
+
+def convert_lrelu(op, arch):
+    # Converts LeakyRelu to a LUT-based solution if possible, otherwise to Mul + Max
+    if op.type != "LeakyRelu":
+        return op
+    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
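+    # An 8-bit LUT maps every IFM value directly to an OFM value, so it
+    # requires 8-bit data and identical IFM/OFM quantization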
+    use_lut = ifm.is_scaling_equal(ofm) and ifm.dtype == ofm.dtype and ifm.dtype in (DataType.uint8, DataType.int8)
+    if use_lut:
+        return convert_lrelu_to_lut(op, arch)
+    return convert_lrelu_to_mul_max(op, arch)
+
+
+def fuse_activation_function_with_prev(op, arch):
+    # If op is a no-op, attempt to move its activation function to the preceding op
+    if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None:
+        return op
+    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+    # find the op that produces this op's IFM
+    prev_op = ifm.ops[0]
+    # Note: the checks on prev_op below require that a first optimization pass over the full graph has already been performed
+    fuse = (
+        prev_op.run_on_npu
+        and prev_op.attrs["npu_block_type"] != NpuBlockType.Default
+        and len(ifm.ops) == 1
+        and len(prev_op.outputs[0].consumers()) == 1
+        and prev_op.attrs.get("fused_activation_function", None) is None
+        and ifm.is_scaling_equal(ofm)
+    )
+    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
+        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
+        # LUT currently only works correctly for elementwise ops
+        fuse = False
+    if fuse and op.activation_lut is not None:
+        # Check if LUT can be used with prev_op
+        prev_ifm, prev_ifm2, _, _ = prev_op.get_ifm_ifm2_weights_ofm()
+        fuse = prev_ifm is not None and prev_ifm.quantization is not None and prev_ifm.is_scaling_equal(ifm)
+        if prev_ifm2 is not None:
+            fuse = fuse and prev_ifm2.quantization is not None and prev_ifm2.is_scaling_equal(ifm)
+    if not fuse:
+        return op
+    # Move the fused activation function + corresponding info to prev_op
+    for attr in ("fused_activation_function", "alpha"):
+        if attr in op.attrs:
+            prev_op.attrs[attr] = op.attrs[attr]
+    if op.activation_lut is not None:
+        prev_op.set_activation_lut(op.activation_lut)
+    # Bypass op: prev_op now produces op's output tensor directly
+    prev_op.set_output_tensor(op.outputs[0])
+    return op
+
+
 def add_attrs_to_resizebilinear(op, arch):
     if op.type == "ResizeBilinear" and op.run_on_npu:
         input_tensor = op.inputs[0]
@@ -679,7 +829,8 @@
         reorder_depthwise_weights,
         fixup_resizebilinear,
         add_bias_tensor,
-        # convert_mul_max_to_abs_or_lrelu # TODO: enable optimisation once quantisation issues are resolved
+        convert_mul_max_to_abs_or_lrelu,
+        convert_lrelu,
     ]
 
     for idx, sg in enumerate(nng.subgraphs):
@@ -689,8 +840,10 @@
         )
 
     for idx, sg in enumerate(nng.subgraphs):
-        # remove passthrough tensors
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(sg, arch, [remove_passthrough_tensor], [])
+        # remove passthrough tensors and attempt further optimizations
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev]
+        )
 
     if verbose_graph:
         nng.print_graph()