MLBEDSW-6870 Optimisations for PReLU

Added optimisations for PReLU when the alpha values allows it.

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: Iff9124e691663ee495379f89900e7c35dbc5f948
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 38e3f60..aaa778e 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -753,10 +753,74 @@
         if None in (ifm, alpha, ofm):
             return op
 
+        if alpha.values is not None:
+            # If const alpha check for possible optimisations
+            alpha_zp = alpha.quantization.zero_point
+            alpha_scale = alpha.quantization.scale_f32
+            # If all alpha values are the same the PReLU can be converted to LeakyRelu
+            alpha_min = (alpha.values.min().astype(np.int) - alpha_zp) * alpha_scale
+            alpha_max = (alpha.values.max().astype(np.int) - alpha_zp) * alpha_scale
+            if alpha_min == alpha_max:
+                # or even a Relu
+                if alpha_min == 0:
+                    new_op = Op.Relu
+                else:
+                    new_op = Op.LeakyRelu
+                    op.attrs["alpha"] = alpha_min
+                    # setup alpha_scaling for bit exact result
+                    ifm_scale = ifm.quantization.scale_f32
+                    ofm_scale = ofm.quantization.scale_f32
+                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
+                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
+                # Change op type
+                op.type = new_op
+                op.name = op.name.replace("Prelu", new_op.name)
+                del op.inputs[1]  # Remove alpha tensor
+                return op
+            elif alpha_max < 1:
+                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
+                # Multiply with alpha tensor
+                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
+                mul_alpha.add_input_tensor(ifm)
+                mul_alpha.add_input_tensor(alpha)
+                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
+                mul_alpha.set_output_tensor(fm_alpha)
+                mul_alpha.set_ifm_ofm_shapes()
+                DebugDatabase.add_optimised(op, mul_alpha)
+                if check_quantized_tens_scaling_equal(ifm, ofm):
+                    # No scaling is needed
+                    fm_id = ifm
+                else:
+                    # Add multiplication with identity
+                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
+                    mul_identity.add_input_tensor(ifm)
+                    # Create const tensor containing identity as scalar
+                    quantization = ifm.quantization.clone()
+                    quantization.scale_f32 = np.float32(1)
+                    quantization.zero_point = 0
+                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
+                    mul_identity.add_input_tensor(one)
+                    # Make sure that fm_id is allocated to a different address than fm_alpha
+                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
+                    mul_identity.set_output_tensor(fm_id)
+                    mul_identity.set_ifm_ofm_shapes()
+
+                # Combine scaled and alpha multiplied values
+                max_op = Operation(Op.Maximum, op.name + "_max")
+                max_op.add_input_tensor(fm_alpha)
+                max_op.add_input_tensor(fm_id)
+                max_op.set_output_tensor(ofm)
+                max_op.set_ifm_ofm_shapes()
+
+                DebugDatabase.add_optimised(op, max_op)
+                ifm.consumer_list.remove(op)
+                return max_op
+
+        # Catch all PReLU conversion for the cases that could not be optimised above
         no_scale_quant = ifm.quantization.clone()
         no_scale_quant.scale_f32 = None
         no_scale_quant.zero_point = 0
-        zero = create_const_tensor("zero_const", [1, 1, 1, 1], ifm.dtype, [0], quantization=no_scale_quant)
+        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
 
         # Select values < 0
         min_op = Operation(Op.Minimum, op.name + "_min")
@@ -816,7 +880,12 @@
             mul = muls[0].ops[0]
         elif len(muls) == 2:
             # In the case both inputs are Muls, find the one with the same input as the Max
-            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
+            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
+            if len(mul_ifms):
+                mul = mul_ifms[0].ops[0]
+            else:
+                # Not using same input
+                return op
         else:
             # No Mul inputs
             return op
@@ -954,17 +1023,19 @@
     quantization.min = 0
     quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
     quantization.zero_point = 0
-    if np.isinf(1 / alpha):
+    if "alpha_scaling" in op.attrs:
+        # The LeakyRelu was the result from convert_prelu
+        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
+        mul_alpha.type = Op.RescaleMul
+        mul_alpha.rescale = [alpha_scale, alpha_shift]
+    elif np.isinf(1 / alpha):
         # Handling of alpha near zero
         quantization.scale_f32 = np.float32(1)
         scalar = 0
     else:
         quantization.scale_f32 = alpha
-        scalar = alpha
-    alpha_tens = create_const_tensor(
-        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
-    )
-    alpha_tens.values = np.array([1])
+        scalar = 1
+    alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [], ifm.dtype, [scalar], quantization=quantization)
     mul_alpha.add_input_tensor(alpha_tens)
     fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
     mul_alpha.set_output_tensor(fm_alpha)