MLBEDSW-2506: Swap broadcast input if applicable

Signed-off-by: Charles Xu <charles.xu@arm.com>
Change-Id: I6e8a97486aa2e1a21101f7cc32cd3024a376162a
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 4cfac33..fff192d 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -256,6 +256,20 @@
                     ofm_tensor = op.outputs[0]
                 build_pass((op,), ofm_tensor)
 
+    def broadcast_input_check(ps):
+        # Swap the IFMs so that the broadcast one becomes IFM2; shapes are compared right-aligned (numpy-style).
+        # NOTE(review): a swap changes the result of non-commutative ops (e.g. Sub) - confirm this is compensated.
+        if len(ps.inputs) == 1 or not ps.inputs[0].shape or not ps.inputs[1].shape:
+            return
+        ifm_shape, ifm2_shape = list(ps.inputs[0].shape), list(ps.inputs[1].shape)
+        rank = max(len(ifm_shape), len(ifm2_shape))
+        ifm_shape, ifm2_shape = [1] * (rank - len(ifm_shape)) + ifm_shape, [1] * (rank - len(ifm2_shape)) + ifm2_shape
+        differing = [a for a, b in zip(ifm_shape, ifm2_shape) if a != b]
+        if not differing or any(a != 1 for a in differing):
+            return
+        ps.inputs[0], ps.inputs[1] = ps.inputs[1], ps.inputs[0]
+        ps.primary_op.inputs[0], ps.primary_op.inputs[1] = ps.primary_op.inputs[1], ps.primary_op.inputs[0]
+
     def build_pass(start_ops_to_process, ofm_tensor=None):
         reverse_ops_list = []
         curr_flags = PassFlags.Empty
@@ -400,6 +414,9 @@
 
         # ElementWise operation, 2 IFMs
         if ps.primary_op and ps.primary_op.type in binary_elem_wise_main_ops:
+            # Swap broadcast input if applicable
+            broadcast_input_check(ps)
+
             ps.ifm_tensor = ps.inputs[0]
 
             if len(ps.inputs) == 1: