[MLBEDSW-3227] Improve u65 softmax performance

Improve u65 softmax performance by selecting more feature map
tensors as SRAM candidates:

* Consider tensors that are used by only a single cascaded pass
  for fast storage, instead of requiring at least two uses.
* Follow chains of no-op Reshapes when deciding whether the
  NHCWB16 format can be used, and rewrite every Reshape in a
  compatible chain (see the sketch below).
* Fall back to the last op in ops for memory-only passes, which
  have no primary_op.
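
A rough standalone sketch of the recursive consumer walk (using a
hypothetical Node class instead of vela's Operation/Tensor types):
rather than flagging any Reshape consumer as incompatible, the walk
follows the chain and only judges its endpoints.

    class Node:
        def __init__(self, op_type, run_on_npu=True, consumers=None):
            self.op_type = op_type
            self.run_on_npu = run_on_npu
            self.consumers = consumers or []

    def incompatible_consumers(node, last_op):
        # Recurse through Reshape chains; an endpoint is incompatible
        # if it is missing (a null list terminator), runs on the CPU,
        # or is the last op in the subgraph.
        if node is not None and node.op_type == "Reshape":
            for consumer in node.consumers:
                yield from incompatible_consumers(consumer, last_op)
        yield node is None or not node.run_on_npu or node is last_op

    # Reshape -> Reshape -> Conv2D: the whole chain stays on the NPU,
    # so NHCWB16 can be used and both Reshapes can be rewritten away.
    last_op = Node("FullyConnected")
    chain = Node("Reshape", consumers=[Node("Reshape", consumers=[Node("Conv2D")])])
    print(any(incompatible_consumers(chain, last_op)))  # False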

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I239c9dbebbf2a929004eb01bb0f3efe77f5b97aa
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 59c2b58..56f4aaa 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -958,6 +958,9 @@
             # to be moved to fast storage
             fast_storage_tensor_rewrites = {}
             last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
+            # Memory-only passes have no primary_op, so use the last op in ops
+            if last_op_in_subgraph is None:
+                last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].ops[-1]
             for ps in self.sg.cascaded_passes:
                 if ps.placement != PassPlacement.Npu:
                     continue
@@ -976,25 +979,32 @@
                         if op.type == Op.ReduceSum and output.dtype == DataType.int32:
                             use_NHCWB16 = False
                         elif op.type == Op.Reshape:
-                            # Detect no-op reshapes by comparing their full input and output tensor shapes.
-                            inshape = full_shape(4, op.inputs[0].shape, 1)
-                            outshape = full_shape(4, op.outputs[0].shape, 1)
                             # Using NHCWB16 format for a no-op reshape is only an option if subsequent
                             # consumers do not also need to perform a reshape or if the OFM is going to
                             # be processed by CPU operations. No-op reshape consumers with empty lists
                             # (those that have no consumers, or null-consumers used as list terminators)
                             # must use normal NHWC output.
-                            incompatible_consumers = [
-                                (
-                                    not consumer.run_on_npu
-                                    or consumer.type == Op.Reshape
-                                    or (consumer is last_op_in_subgraph)
-                                )
-                                for consumer in op.outputs[0].consumer_list
-                                if consumer is not None
-                            ]
-                            if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
-                                rewrites.append(op)
+                            def incompatible_consumers(oper):
+                                if oper and oper.type == Op.Reshape:
+                                    for consumer in oper.outputs[0].consumer_list:
+                                        yield from incompatible_consumers(consumer)
+                                yield not oper or not oper.run_on_npu or oper is last_op_in_subgraph
+
+                            if not any(incompatible_consumers(op)):
+
+                                def get_rewrites(oper):
+                                    if oper and oper.type == Op.Reshape:
+                                        for consumer in oper.outputs[0].consumer_list:
+                                            yield from get_rewrites(consumer)
+                                        yield oper
+
+                                rewrites.extend(get_rewrites(op))
+                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
+                                inshape = full_shape(4, op.inputs[0].shape, 1)
+                                compatible_shape = [
+                                    (inshape == full_shape(4, oper.outputs[0].shape, 1)) for oper in get_rewrites(op)
+                                ]
+                                use_NHCWB16 = compatible_shape and all(compatible_shape)
                             else:
                                 use_NHCWB16 = False
                                 use_fast_storage = False
@@ -1069,7 +1079,7 @@
     tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
     for _, _, _, tens in tens_list:
         cps_list = tens_to_cps[tens]
-        if len(cps_list) <= 1:
+        if len(cps_list) < 1:
             continue
         sz = tens.storage_size()
         fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list])
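
The last hunk relaxes the fast-storage candidate filter: tensors
referenced by exactly one cascaded pass were previously skipped
outright, whereas now only tensors with no referencing passes are
skipped, so single-pass feature maps (such as softmax intermediates)
get checked against the SRAM limit like any other candidate. A
minimal illustration (the pass list and its contents are made up,
not taken from the scheduler):

    def skipped_old(cps_list):
        return len(cps_list) <= 1  # single-pass tensors never reach SRAM

    def skipped_new(cps_list):
        return len(cps_list) < 1   # only unreferenced tensors are skipped

    cps_list = ["softmax_pass"]    # one cascaded pass uses this tensor
    print(skipped_old(cps_list))   # True: previously not even considered
    print(skipped_new(cps_list))   # False: now evaluated against the limit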