MLBEDSW-6832 PReLU support in Vela

Added PReLU support to the graph optimiser by rewriting the Prelu op
into a Minimum, Mul, Relu and RescaleAdd op sequence.
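
For reference, the rewrite is numerically equivalent to the following
float-domain sketch (illustrative only; quantisation, zero points and
the RescaleAdd pass-through rescale of (1, 0) are not modelled):

    import numpy as np

    def prelu_reference(ifm, alpha):
        # PReLU: positive values pass through unchanged, negative values
        # are scaled by alpha, mirroring the op sequence created by
        # convert_prelu().
        fm_negative = np.minimum(ifm, 0.0)  # Minimum op
        fm_alpha = fm_negative * alpha      # Mul op
        fm_scaled = np.maximum(ifm, 0.0)    # Relu op
        return fm_alpha + fm_scaled         # RescaleAdd op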

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I3a188675e3edcdf0b4a4bfcdd134fda0bf8a560f
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index ed8fa1e..3646b01 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -746,6 +746,60 @@
     return op
 
 
+def convert_prelu(op, arch, nng):
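+    """Convert Prelu to Minimum + Mul + Relu + RescaleAdd: ofm = Relu(ifm) + alpha * min(ifm, 0)."""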
+    if op.type == Op.Prelu:
+        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
+        if None in (ifm, alpha, ofm):
+            return op
+
+        no_scale_quant = ifm.quantization.clone()
+        no_scale_quant.scale_f32 = None
+        no_scale_quant.zero_point = 0
+        zero = create_const_tensor("zero_const", [1, 1, 1, 1], ifm.dtype, [0], quantization=no_scale_quant)
+
+        # Select values < 0
+        min_op = Operation(Op.Minimum, op.name + "_min")
+        min_op.add_input_tensor(ifm)
+        min_op.add_input_tensor(zero)
+        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
+        min_op.set_output_tensor(fm_negative)
+        min_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, min_op)
+
+        # and multiply them by the alpha tensor
+        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
+        mul_alpha.add_input_tensor(fm_negative)
+        mul_alpha.add_input_tensor(alpha)
+        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
+        mul_alpha.set_output_tensor(fm_alpha)
+        mul_alpha.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, mul_alpha)
+
+        # Select (and scale) values > 0
+        relu_op = Operation(Op.Relu, op.name + "_relu")
+        relu_op.add_input_tensor(ifm)
+        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
+        relu_op.set_output_tensor(fm_scaled)
+        relu_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, relu_op)
+
+        # Add the scaled and alpha-multiplied values (with no further rescaling)
+        add_op = Operation(Op.RescaleAdd, op.name + "_add")
+        add_op.rescale = (1, 0)  # No scale or shift
+        add_op.add_input_tensor(fm_alpha)
+        add_op.add_input_tensor(fm_scaled)
+        add_op.set_output_tensor(ofm)
+        add_op.set_ifm_ofm_shapes()
+
+        DebugDatabase.add_optimised(op, add_op)
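+        # Detach the original Prelu from ifm; the decomposed ops now consume ifm and produce ofm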
+        ifm.consumer_list.remove(op)
+        op = add_op
+
+    return op
+
+
 def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
     r"""Whenever there is a subgraph with this topology:
 
@@ -1648,6 +1702,7 @@
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,
+        convert_prelu,
         optimise_strided_conv,
         convert_hardswish_to_lut,
         rewrite_fully_connected_input,