MLBEDSW-2337: Intermediate feature maps in fast storage

Attempts to use fast storage for feature maps that are passed between
cascaded passes.

This is only relevant for system configurations where feature maps are
not placed in SRAM by default, but SRAM is available for fast storage.
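
In outline: while building cascaded passes, the scheduler records which
inter-pass feature maps are eligible, and use_fast_storage_for_feature_maps
later computes each tensor's live range over the NPU passes and promotes
the tensor only if every pass in that range stays under the SRAM limit.
A minimal sketch of the live-range bookkeeping (npu_passes is an
illustrative stand-in, not a name from this patch):

    # A tensor's live range is the index interval from the first pass
    # that produces or consumes it to the last pass that consumes it;
    # it occupies SRAM in every pass inside that interval.
    live = {}  # tensor -> (first_index, last_index)
    for i, cps in enumerate(npu_passes):
        for tens in cps.inputs + cps.outputs:
            lo, hi = live.get(tens, (i, i))
            live[tens] = (min(lo, i), max(hi, i))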

Change-Id: I207b7cf32cfcb5bea3e6b93c2da1161c4af5221d
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 24453d8..5c2ddab 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -959,52 +959,66 @@
         self.sg.cascaded_passes = cascaded_passes
         self.sg.build_cascaded_pass_links()
 
-        if self.options.use_nhcwb16_between_cascaded_passes:
-            # Check if NHCWB16 can be used in between cascaded passes
-            # (NHCWB16 within cascaded passes has been handled earlier in this function)
-            if self.sg.placement == PassPlacement.Npu:
-                last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
-                for ps in self.sg.cascaded_passes:
-                    if ps.placement != PassPlacement.Npu:
+        # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
+        # (NHCWB16 within cascaded passes has been handled earlier in this function)
+        if self.sg.placement == PassPlacement.Npu:
+            # Dictionary tensor -> list of rewrite ops, for feature maps that may be
+            # moved to fast storage
+            fast_storage_tensor_rewrites = {}
+            last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
+            for ps in self.sg.cascaded_passes:
+                if ps.placement != PassPlacement.Npu:
+                    continue
+                for output in ps.outputs:
+                    if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
                         continue
-                    for output in ps.outputs:
-                        if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
+
+                    use_NHCWB16 = True
+                    use_fast_storage = True
+                    rewrites = []
+                    for op in output.consumer_list:
+                        if op is None:
+                            use_NHCWB16 = False
+                            use_fast_storage = False
                             continue
-
-                        use_NHCWB16 = True
-                        rewrites = []
-                        for op in output.consumer_list:
-                            if op is None or (op.type == "ReduceSum" and output.dtype == DataType.int32):
-                                use_NHCWB16 = False
-                            elif op.type == "Reshape":
-                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
-                                inshape = full_shape(4, op.inputs[0].shape, 1)
-                                outshape = full_shape(4, op.outputs[0].shape, 1)
-                                # Using NHCWB16 format for a no-op reshape is only an option if subsequent
-                                # consumers do not also need to perform a reshape or if the OFM is going to
-                                # be processed by CPU operations. No-op reshape consumers with empty lists
-                                # (those that have no consumers, or null-consumers used as list terminators)
-                                # must use normal NHWC output.
-                                incompatible_consumers = [
-                                    (
-                                        not consumer.run_on_npu
-                                        or consumer.type == "Reshape"
-                                        or (consumer is last_op_in_subgraph)
-                                    )
-                                    for consumer in op.outputs[0].consumer_list
-                                    if consumer is not None
-                                ]
-                                if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
-                                    rewrites.append(op)
-                                else:
-                                    use_NHCWB16 = False
+                        if op.type == "ReduceSum" and output.dtype == DataType.int32:
+                            use_NHCWB16 = False
+                        elif op.type == "Reshape":
+                            # Detect no-op reshapes by comparing their full input and output tensor shapes.
+                            inshape = full_shape(4, op.inputs[0].shape, 1)
+                            outshape = full_shape(4, op.outputs[0].shape, 1)
+                            # Using NHCWB16 format for a no-op reshape is only an option if no
+                            # subsequent consumer needs a reshape of its own, runs on the CPU, or
+                            # is the last op in the subgraph. No-op reshapes whose consumer list
+                            # is effectively empty (no consumers, or only null consumers used as
+                            # list terminators) must keep normal NHWC output.
+                            incompatible_consumers = [
+                                (
+                                    not consumer.run_on_npu
+                                    or consumer.type == "Reshape"
+                                    or (consumer is last_op_in_subgraph)
+                                )
+                                for consumer in op.outputs[0].consumer_list
+                                if consumer is not None
+                            ]
+                            if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
+                                rewrites.append(op)
                             else:
-                                use_NHCWB16 &= op.run_on_npu
+                                use_NHCWB16 = False
+                                use_fast_storage = False
+                        use_NHCWB16 &= op.run_on_npu
+                        use_fast_storage &= op.run_on_npu
 
-                        if use_NHCWB16:
-                            output.set_format(TensorFormat.NHCWB16, arch)
-                            for rewrite_op in rewrites:
-                                rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
+                    if use_fast_storage:
+                        fast_storage_tensor_rewrites[output] = rewrites
+                    if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
+                        output.set_format(TensorFormat.NHCWB16, arch)
+                        for rewrite_op in rewrites:
+                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
+            if self.feature_maps_not_in_fast_storage:
+                # Remember feature maps that can be moved to fast storage for later use
+                # in use_fast_storage_for_feature_maps
+                self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites
 
 
 def schedule_passes(nng, arch, options: SchedulerOptions):
@@ -1027,3 +1041,75 @@
 
         if options.verbose_schedule:
             sg.print_cascaded_passes()
+
+
+def _calc_tens_to_cps(sg, tensor_rewrites):
+    # Determines, for each tensor, which cascaded passes it occupies SRAM in.
+    # Returns a dictionary tensor -> list of cascaded passes.
+    # Note: if cascaded passes are A, B, C, D, and a tensor is output
+    # of A and input to D, then it also consumes SRAM in passes B and C.
+    if "tens_to_cps" in sg.scheduling_info:
+        return sg.scheduling_info["tens_to_cps"]
+    # Determine life-time of tensors
+    min_index = {}
+    max_index = {}
+    index = 0
+    cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
+    for cps in cps_list:
+        for tens in cps.inputs + cps.outputs:
+            if tens in tensor_rewrites:
+                min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
+                max_index[tens] = index
+        index += 1
+    # Convert life-time indices to the list of affected cascaded passes
+    tens_to_cps = {}
+    for tens in min_index:
+        tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
+    sg.scheduling_info["tens_to_cps"] = tens_to_cps
+    return tens_to_cps
+
+
+def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
+    # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes.
+    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
+    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
+    # Sort tensors first on life-time (smallest first), then on size (biggest first)
+    tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
+    for _, _, _, tens in tens_list:
+        cps_list = tens_to_cps[tens]
+        if len(cps_list) <= 1:
+            continue
+        sz = tens.storage_size()
+        fits_in_fast_storage = all(cps.sram_used + sz <= sram_limit for cps in cps_list)
+        if fits_in_fast_storage:
+            tens.mem_area = arch.fast_storage_mem_area
+            tens.mem_type = MemType.Scratch_fast
+            tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
+            assert tens in tensor_rewrites
+            # Also rewrite reshapes
+            for rewrite_op in tensor_rewrites[tens]:
+                tens2 = rewrite_op.outputs[0]
+                tens2.mem_area = arch.fast_storage_mem_area
+                tens2.mem_type = MemType.Scratch_fast
+                tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
+            for cps in cps_list:
+                cps.sram_used += sz
+
+
+def undo_use_fast_storage(sg, arch):
+    # Undoes the effects of a previous call to use_fast_storage_for_feature_maps
+    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
+    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
+    mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]
+    for tens, cps_list in tens_to_cps.items():
+        if tens.mem_type == MemType.Scratch_fast:
+            sz = tens.storage_size()
+            tens.mem_area = mem_area
+            tens.mem_type = MemType.Scratch
+            # Also undo reshapes
+            for rewrite_op in tensor_rewrites[tens]:
+                tens2 = rewrite_op.outputs[0]
+                tens2.mem_area = mem_area
+                tens2.mem_type = MemType.Scratch
+            for cps in cps_list:
+                cps.sram_used -= sz
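
For context, a hedged sketch of how a driver could pair these functions
as a try/rollback step around allocation; allocate_feature_maps and
arch.sram_size are illustrative stand-ins, not names from this patch:

    def schedule_with_fast_storage(sg, arch, allocate_feature_maps):
        # Optimistically promote shared feature maps to fast storage,
        # then check that allocation still succeeds; roll back if not.
        use_fast_storage_for_feature_maps(sg, arch.sram_size, arch)
        if not allocate_feature_maps(sg, arch):  # hypothetical allocator
            undo_use_fast_storage(sg, arch)
            allocate_feature_maps(sg, arch)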