MLBEDSW-6832 PReLU support in Vela

Added PReLU support to the graph optimiser by rewriting the operation
as Minimum, Mul, Relu and RescaleAdd operations.

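The rewrite uses the identity PReLU(x) = Relu(x) + alpha * min(x, 0).
A minimal NumPy sketch of the same decomposition (illustrative only,
ignoring quantisation; not the Vela code path):

    import numpy as np

    def prelu_reference(x, alpha):
        # Mirrors the graph rewrite: Minimum selects the negative part,
        # Mul applies alpha, Relu selects the positive part, and the
        # final add recombines them without further rescaling.
        negative = np.minimum(x, 0.0)        # Minimum(ifm, zero)
        negative_alpha = negative * alpha    # Mul(negative, alpha)
        positive = np.maximum(x, 0.0)        # Relu(ifm)
        return negative_alpha + positive     # RescaleAdd, rescale=(1, 0)

    x = np.array([-2.0, -0.5, 0.0, 1.5])
    assert np.allclose(prelu_reference(x, 0.25),
                       np.where(x > 0, x, 0.25 * x))
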
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I3a188675e3edcdf0b4a4bfcdd134fda0bf8a560f
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 47f4fe0..54e823a 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -229,7 +229,7 @@
     PadV2 = OperatorInfo()
     Placeholder = OperatorInfo()  # Only used in CPU subgraphs
     Pow = OperatorInfo()
-    Prelu = OperatorInfo()
+    Prelu = OperatorInfo(indices=NNG_IFM_IFM2_INDICES)
     Prod = OperatorInfo()
     Quantize = OperatorInfo(indices=NNG_IFM_INDICES)
     QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index ed8fa1e..3646b01 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -746,6 +746,58 @@
     return op
 
 
+def convert_prelu(op, arch, nng):
+    if op.type == Op.Prelu:
+        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
+        if None in (ifm, alpha, ofm):
+            return op
+
+        no_scale_quant = ifm.quantization.clone()
+        no_scale_quant.scale_f32 = None
+        no_scale_quant.zero_point = 0
+        zero = create_const_tensor("zero_const", [1, 1, 1, 1], ifm.dtype, [0], quantization=no_scale_quant)
+
+        # Select values < 0
+        min_op = Operation(Op.Minimum, op.name + "_min")
+        min_op.add_input_tensor(ifm)
+        min_op.add_input_tensor(zero)
+        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
+        min_op.set_output_tensor(fm_negative)
+        min_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, min_op)
+
+        # and multiply with alpha tensor
+        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
+        mul_alpha.add_input_tensor(fm_negative)
+        mul_alpha.add_input_tensor(alpha)
+        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
+        mul_alpha.set_output_tensor(fm_alpha)
+        mul_alpha.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, mul_alpha)
+
+        # Select (and scale) values > 0
+        relu_op = Operation(Op.Relu, op.name + "_relu")
+        relu_op.add_input_tensor(ifm)
+        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
+        relu_op.set_output_tensor(fm_scaled)
+        relu_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, relu_op)
+
+        # Add scaled and alpha multiplied values (without scaling)
+        add_op = Operation(Op.RescaleAdd, op.name + "_add")
+        add_op.rescale = (1, 0)  # No scale or shift
+        add_op.add_input_tensor(fm_alpha)
+        add_op.add_input_tensor(fm_scaled)
+        add_op.set_output_tensor(ofm)
+        add_op.set_ifm_ofm_shapes()
+
+        DebugDatabase.add_optimised(op, add_op)
+        ifm.consumer_list.remove(op)
+        op = add_op
+
+    return op
+
+
 def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
     r"""Whenever there is a subgraph with this topology:
 
@@ -1648,6 +1700,7 @@
         convert_depthwise_to_conv,
         convert_conv_to_fc,
         convert_softmax,
+        convert_prelu,
         optimise_strided_conv,
         convert_hardswish_to_lut,
         rewrite_fully_connected_input,
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index c515d23..3ccedc7 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -732,7 +732,7 @@
         ),
         TFLITE_NO_INDICES,
     ),
-    BuiltinOperator.PRELU: (Op.Prelu, None, TFLITE_NO_INDICES),
+    BuiltinOperator.PRELU: (Op.Prelu, None, TFLITE_IFM_IFM2_INDICES),
     BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
     BuiltinOperator.ARG_MAX: (
         Op.ArgMax,
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 5d25e37..1915d43 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -123,7 +123,15 @@
             Op.Clip,
         )
     )
-    activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
+    activation_ops = relu_ops | set(
+        (
+            Op.Tanh,
+            Op.Sigmoid,
+            Op.Softmax,
+            Op.HardSwish,
+            Op.Prelu,
+        )
+    )
     npu_post_ops = (
         # activation functions
         activation_ops
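
The pass above follows the graph optimiser's rewrite convention: each
function takes (op, arch, nng) and returns either the original op or
the operation that replaced it. A simplified sketch of that driver
pattern (illustrative only; apply_rewrites is a hypothetical helper,
the real loop lives in Vela's graph optimiser):

    def apply_rewrites(ops, rewrites, arch=None, nng=None):
        # Run every rewrite function over every operation; a rewrite
        # returns op unchanged, or the op that now produces the OFM
        # (here, the final RescaleAdd of the PReLU decomposition).
        optimised = []
        for op in ops:
            for rewrite in rewrites:
                op = rewrite(op, arch, nng)
            optimised.append(op)
        return optimised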