MLBEDSW-2935: LUT fusing with preceding operator

Allows fusing of a LUT activation with the preceding operator regardless
of input/output scale: the fused operator takes over the activation's
output tensor and any forced output quantization, so equal scaling
between its IFM and OFM is no longer required.
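
A minimal sketch of why the scale restriction is unnecessary (illustrative
NumPy, not code from this patch): an int8 activation LUT maps every possible
quantized input value directly to a quantized output value, so the input and
output scaling is folded into the table itself rather than into the operator
that writes through it.

    import numpy as np

    def make_int8_lut(act, in_scale, in_zp, out_scale, out_zp):
        # One entry per possible int8 input value (-128..127).
        values = np.arange(-128, 128, dtype=np.int32)
        real = (values - in_zp) * in_scale                 # dequantize input
        quant = np.round(act(real) / out_scale) + out_zp   # requantize result
        return np.clip(quant, -128, 127).astype(np.int8)

    # Example: a tanh LUT where the producer's output scale (0.05) differs
    # freely from the activation's output scale (1/128).
    lut = make_int8_lut(np.tanh, in_scale=0.05, in_zp=0,
                        out_scale=1.0 / 128, out_zp=0)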

Change-Id: Ia378adbb3fe61d71299feb085f7313377e0efa39
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index a89f8e6..b9110b8 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -823,28 +823,21 @@
         and len(ifm.ops) == 1
         and len(prev_op.outputs[0].consumers()) == 1
         and prev_op.attrs.get("fused_activation_function", None) is None
-        and ifm.is_scaling_equal(ofm)
     )
     if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
         # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
         # LUT currently only works correctly for elementwise ops
         fuse = False
-    if fuse and op.activation_lut is not None:
-        # Check if LUT can be used with prev_op
-        prev_ifm, prev_ifm2, _, _ = prev_op.get_ifm_ifm2_weights_ofm()
-        fuse = prev_ifm is not None and prev_ifm.quantization is not None and prev_ifm.is_scaling_equal(ifm)
-        if prev_ifm2 is not None:
-            fuse = fuse and prev_ifm2.quantization is not None and prev_ifm2.is_scaling_equal(ifm)
     if not fuse:
         return op
     # Move the fused activation function + corresponding info to prev_op
-    for attr in ("fused_activation_function", "alpha", "forced_output_quantization"):
+    for attr in ("fused_activation_function", "forced_output_quantization"):
         if attr in op.attrs:
             prev_op.attrs[attr] = op.attrs[attr]
     if op.activation_lut is not None:
         prev_op.set_activation_lut(op.activation_lut)
     # Bypass op
-    prev_op.set_output_tensor(op.outputs[0])
+    prev_op.set_output_tensor(ofm)
     return op
 
 
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 4b83b39..e7fd97c 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -200,6 +200,10 @@
 
         return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor
 
+    def get_ofm(self):
+        _, _, _, ofm = self.get_ifm_ifm2_weights_ofm()
+        return ofm
+
     def is_concat_op(self):
         return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped")
 
@@ -361,3 +365,6 @@
             "Conv2DBackpropInputSwitchedBias",
             "FullyConnectedAct",
         )
+
+    def get_output_quantization(self):
+        return self.attrs.get("forced_output_quantization", self.get_ofm().quantization)
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 0a35647..8f34e63 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -909,7 +909,11 @@
                 if tens is None:
                     continue
 
-                need_zero_point = (faf is not None) or (fmf == "ConcatSliceWrite") or fused_quantize
+                need_zero_point = (
+                    (faf is not None and forced_ofm_quantization is None)
+                    or (fmf == "ConcatSliceWrite")
+                    or fused_quantize
+                )
                 if (
                     (
                         primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "CLZ", "SHL"))
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 175646b..2374cd4 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -416,13 +416,13 @@
     first_consumer_op = tens.consumer_list[0]
     ifm_dtype = first_consumer_op.inputs[0].dtype
     ifm_scale = first_consumer_op.inputs[0].quantization.scale_f32
-    ofm_scale = first_consumer_op.outputs[0].quantization.scale_f32
+    ofm_scale = first_consumer_op.get_output_quantization().scale_f32
     weight_scales = first_consumer_op.inputs[1].quantization.scale_f32
 
     # biases can have multiple consumers for rnn cells. if so, then check that they are all the same
     for op in tens.consumer_list[1:]:
         assert ifm_scale == op.inputs[0].quantization.scale_f32
-        assert ofm_scale == op.outputs[0].quantization.scale_f32
+        assert ofm_scale == op.get_output_quantization().scale_f32
         assert weight_scales == op.inputs[1].quantization.scale_f32
 
     if not hasattr(weight_scales, "__iter__"):