MLBEDSW-2639: Moved the IFM/IFM2 order switch to register cmd stream generator

For binary elementwise ops with broadcasting in first IFM.

Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com>
Change-Id: I25af67be8d3a852247989bc3ddc8e08e946f6bfa
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 8fb95f0..fab00e0 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -259,20 +259,6 @@
                     ofm_tensor = op.outputs[0]
                 build_pass((op,), ofm_tensor)
 
-    def broadcast_input_check(ps):
-        if len(ps.inputs) == 1 or ps.inputs[0].shape == ps.inputs[1].shape:
-            return
-
-        if ps.inputs[0].shape == [] or ps.inputs[1].shape == []:
-            return
-
-        for idx in range(len(ps.inputs[1].shape)):
-            if ps.inputs[1].shape[idx] != ps.inputs[0].shape[idx] and ps.inputs[0].shape[idx] != 1:
-                return
-
-        ps.inputs[0], ps.inputs[1] = ps.inputs[1], ps.inputs[0]
-        ps.primary_op.inputs[0], ps.primary_op.inputs[1] = ps.primary_op.inputs[1], ps.primary_op.inputs[0]
-
     def build_pass(start_ops_to_process, ofm_tensor=None):
         reverse_ops_list = []
         curr_flags = PassFlags.Empty
@@ -413,14 +399,9 @@
         ps.inputs = ordered_input_list
         ps.intermediates = intermediates
         ps.outputs = list(ops_list[-1].outputs)
-        ps.ifm_tensor = ifm_tensor
 
         # ElementWise operation, 2 IFMs
         if ps.primary_op and ps.primary_op.type in binary_elem_wise_main_ops:
-            # Swap broadcast input if applicable
-            broadcast_input_check(ps)
-
-            # If only 1 input, IFM and IFM2 will be the same tensor
             ps.ifm_tensor = ps.inputs[0]
             ps.ifm2_tensor = ps.inputs[-1]
 
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index e0f114e..471953d 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -330,6 +330,21 @@
     return (explicit_padding[1], explicit_padding[0])
 
 
+def ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
+    if ifm_shape == []:
+        # Scalar needs to be in IFM2
+        return False
+    elif ifm2_shape == []:
+        return True
+
+    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
+        if ifm != ifm2 and ifm == 1:
+            # Broadcasted FM needs to be in IFM2
+            return False
+
+    return True
+
+
 def generate_register_command_stream(nng, sg, arch, verbose=False):
     emit = CommandStreamEmitter()
 
@@ -472,7 +487,7 @@
                     IFM2Broadcast.ReverseOperandOrder if primary_op.attrs.get("reverse_op_order", False) else 0
                 )
 
-                if cmd.ifm_tensor.shape == []:
+                if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
                     # The scalar has to be the ifm2 tensor so switch the ifms
                     cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
                     cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box