MLBEDSW-2337: Intermediate feature maps in fast storage

Attempts to use fast storage for feature maps used between
cascaded passes.

This is only relevant for system configurations where feature maps
are not placed in SRAM by default, but SRAM is available for fast
storage.
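
The SRAM budget for these feature maps is found by bisection:
compiler_driver repeatedly asks next_sram_factor() for the next
fraction of SRAM to try, moves feature maps with
scheduler.use_fast_storage_for_feature_maps(), and dry-runs the
tensor allocator until the largest fraction that still fits has been
found. A minimal sketch of this call pattern, adapted from the new
unit test in test_compiler_driver.py (the 0.83 threshold simulates
the allocator and is illustrative only):

    from ethosu.vela.compiler_driver import next_sram_factor

    alloc_results = []
    best_factor = 0.0
    while True:
        factor, dry_test = next_sram_factor(alloc_results)
        if factor is None:
            break  # bisection done (or the first full-size attempt succeeded)
        # Stand-in for moving feature maps and running the real allocator
        alloc_ok = factor < 0.83
        if alloc_ok and not dry_test:
            best_factor = factor  # this (non-dry) allocation is the one kept
        alloc_results.append(alloc_ok)
    # best_factor now approximates the largest SRAM fraction that fits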

Change-Id: I207b7cf32cfcb5bea3e6b93c2da1161c4af5221d
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index 92fe584..6c1142d 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -88,6 +88,45 @@
     __repr__ = __str__
 
 
+def next_sram_factor(alloc_results):
+    # Bisects to find the max SRAM usage that can successfully be fitted with the tensor allocator.
+    # Returns tuple (factor, dry_test), where factor is None (stop) or 0 <= factor <= 1 (next SRAM factor to try),
+    # and dry_test is True while still bisecting.
+    upper = 1.0
+    lower = 0.7
+    MAX_ITERATIONS = 8
+    if len(alloc_results) == 0:
+        # First iteration, try max SRAM, keep the result if it succeeds
+        return (upper, False)
+    elif len(alloc_results) == 1:
+        if alloc_results[0]:
+            # The allocator succeeded on the first try; stop
+            return (None, False)
+        else:
+            # Start bisecting, try the lower bound SRAM factor
+            return (lower, True)
+    elif len(alloc_results) > MAX_ITERATIONS:
+        # Stop
+        return (None, False)
+    if not alloc_results[1]:
+        # Allocation at lower failed; search interval 0 - lower
+        upper = lower
+        lower = 0
+    best = lower
+    for success in alloc_results[2:]:
+        middle = (lower + upper) / 2
+        if success:
+            best = max(best, middle)
+            lower = middle
+        else:
+            upper = middle
+    if len(alloc_results) == MAX_ITERATIONS:
+        # Done bisecting; repeat the best match, but not as a dry test
+        return (best, False)
+    # Next try; run only as a dry test
+    return ((lower + upper) / 2, True)
+
+
 def compiler_driver(nng, arch, options, scheduler_options):
     assert verify_graph_health(nng)
     nng = graph_optimiser.optimise_graph_a(nng, arch, options.verbose_graph)
@@ -156,11 +195,11 @@
             arch,
             permanent_storage,
             set((MemType.Permanent_NPU,)),
-            scheduler_options.use_ifm_ofm_overlap,
-            TensorAllocator.LinearAlloc,
-            options.verbose_allocation,
-            options.show_minimum_possible_allocation,
-            lr_graph_flash,
+            use_ifm_ofm_overlap=scheduler_options.use_ifm_ofm_overlap,
+            tensor_allocator=TensorAllocator.LinearAlloc,
+            verbose_allocation=options.verbose_allocation,
+            show_minimum_possible_allocation=options.show_minimum_possible_allocation,
+            lr_graph=lr_graph_flash,
         )
 
     # Allocate all non-constant tensors to the root, i.e. Cpu, subgraph. This step
@@ -175,28 +214,68 @@
     root_sg = nng.get_root_subgraph()
 
     alloc_list = []
-    if arch.feature_map_storage_mem_area == arch.fast_storage_mem_area:
+    feature_maps_in_fast_storage = arch.feature_map_storage_mem_area == arch.fast_storage_mem_area
+    if feature_maps_in_fast_storage:
         mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))
         alloc_list.append(mem_alloc_scratch)
     else:
-        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
         mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,)))
-        alloc_list.append(mem_alloc_scratch)
+        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
+        # Order is important: Scratch_fast must be allocated before Scratch
         alloc_list.append(mem_alloc_scratch_fast)
+        alloc_list.append(mem_alloc_scratch)
 
-    for alloc in alloc_list:
-        tensor_allocation.allocate_tensors(
-            nng,
-            root_sg,
-            arch,
-            alloc[0],
-            alloc[1],
-            scheduler_options.use_ifm_ofm_overlap,
-            options.tensor_allocator,
-            options.verbose_allocation,
-            options.show_minimum_possible_allocation,
-            allocation_alignment=options.allocation_alignment,
-        )
+    for mem_area, mem_type_set in alloc_list:
+        if feature_maps_in_fast_storage or mem_area != arch.fast_storage_mem_area:
+            tensor_allocation.allocate_tensors(
+                nng,
+                root_sg,
+                arch,
+                mem_area,
+                mem_type_set,
+                use_ifm_ofm_overlap=scheduler_options.use_ifm_ofm_overlap,
+                tensor_allocator=options.tensor_allocator,
+                verbose_allocation=options.verbose_allocation,
+                show_minimum_possible_allocation=options.show_minimum_possible_allocation,
+                allocation_alignment=options.allocation_alignment,
+            )
+        else:
+            # For the case where scratch_fast != scratch: attempt to place feature maps used between
+            # cascaded passes in fast storage. Bisection is used to find the maximum possible SRAM usage.
+            alloc_results = []
+            while True:
+                assert len(alloc_results) < 10, "Infinite allocator loop"
+                sram_factor, dry_test = next_sram_factor(alloc_results)
+                if sram_factor is None:
+                    break
+                # Try to move as many feature maps as possible to SRAM before allocating
+                sram_limit = sram_factor * arch.sram_size
+                for sg in nng.subgraphs:
+                    scheduler.use_fast_storage_for_feature_maps(sg, sram_limit, arch)
+                alloc_success = tensor_allocation.allocate_tensors(
+                    nng,
+                    root_sg,
+                    arch,
+                    mem_area,
+                    mem_type_set,
+                    max_size=arch.sram_size,
+                    dry_test=dry_test,
+                    use_ifm_ofm_overlap=scheduler_options.use_ifm_ofm_overlap,
+                    tensor_allocator=options.tensor_allocator,
+                    verbose_allocation=options.verbose_allocation,
+                    show_minimum_possible_allocation=options.show_minimum_possible_allocation,
+                    allocation_alignment=options.allocation_alignment,
+                )
+                if dry_test or not alloc_success:
+                    for sg in nng.subgraphs:
+                        scheduler.undo_use_fast_storage(sg, arch)
+                alloc_results.append(alloc_success)
+            if not alloc_results[-1]:
+                raise VelaError(
+                    "Sram limit {} bytes, has been exceeded by the scratch fast tensor. "
+                    "Increasing the value of --weight-estimation-scaling may help to resolve the issue. "
+                    "See OPTIONS.md for more information.".format(arch.sram_size)
+                )
 
     # Generate command streams and serialise Npu-ops into tensors
     for sg in nng.subgraphs:
@@ -213,16 +292,6 @@
 
     npu_serialisation.rewrite_npu_call_ops(nng, root_sg, arch)
 
-    if root_sg is not None and (arch.feature_map_storage_mem_area != arch.fast_storage_mem_area):
-        if root_sg.memory_used_per_type.get(MemType.Scratch_fast, 0) > arch.sram_size:
-            raise VelaError(
-                "Sram limit {} bytes, has been exceeded by the scratch fast tensor {} bytes. "
-                "Increasing the value of --weight-estimation-scaling may help to resolve the issue. "
-                "See OPTIONS.md for more information.".format(
-                    arch.sram_size, root_sg.memory_used_per_type.get(MemType.Scratch_fast, 0)
-                )
-            )
-
     # Allocate all Cpu constant tensors, this is done last because the Npu-ops
     # have to be serialized into flash and scratch tensors first
     tensor_allocation.allocate_tensors(
@@ -231,10 +300,10 @@
         arch,
         permanent_storage,
         set((MemType.Permanent_CPU,)),
-        scheduler_options.use_ifm_ofm_overlap,
-        TensorAllocator.LinearAlloc,
-        options.verbose_allocation,
-        options.show_minimum_possible_allocation,
+        use_ifm_ofm_overlap=scheduler_options.use_ifm_ofm_overlap,
+        tensor_allocator=TensorAllocator.LinearAlloc,
+        verbose_allocation=options.verbose_allocation,
+        show_minimum_possible_allocation=options.show_minimum_possible_allocation,
         allocation_alignment=options.allocation_alignment,
     )
 
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 21cd80b..58aab61 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -141,6 +141,8 @@
         self.placement = placement
         self.command_stream_tensor = None
         self.flash_tensor = None
+        # Scratch information used locally by the scheduler
+        self.scheduling_info = {}
 
         self.memory_used = {}
         self.memory_used_per_type = {}
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 24453d8..5c2ddab 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -959,52 +959,66 @@
         self.sg.cascaded_passes = cascaded_passes
         self.sg.build_cascaded_pass_links()
 
-        if self.options.use_nhcwb16_between_cascaded_passes:
-            # Check if NHCWB16 can be used in between cascaded passes
-            # (NHCWB16 within cascaded passes has been handled earlier in this function)
-            if self.sg.placement == PassPlacement.Npu:
-                last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
-                for ps in self.sg.cascaded_passes:
-                    if ps.placement != PassPlacement.Npu:
+        # Check if NHCWB16 and/or fast storage can be used in between cascaded passes
+        # (NHCWB16 within cascaded passes has been handled earlier in this function)
+        if self.sg.placement == PassPlacement.Npu:
+            # Dictionary tensor -> list of ops; contains the feature maps that are candidates
+            # for being moved to fast storage
+            fast_storage_tensor_rewrites = {}
+            last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
+            for ps in self.sg.cascaded_passes:
+                if ps.placement != PassPlacement.Npu:
+                    continue
+                for output in ps.outputs:
+                    if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
                         continue
-                    for output in ps.outputs:
-                        if output.purpose != TensorPurpose.FeatureMap or output.avoid_NHCWB16:
+
+                    use_NHCWB16 = True
+                    use_fast_storage = True
+                    rewrites = []
+                    for op in output.consumer_list:
+                        if op is None:
+                            use_NHCWB16 = False
+                            use_fast_storage = False
                             continue
-
-                        use_NHCWB16 = True
-                        rewrites = []
-                        for op in output.consumer_list:
-                            if op is None or (op.type == "ReduceSum" and output.dtype == DataType.int32):
-                                use_NHCWB16 = False
-                            elif op.type == "Reshape":
-                                # Detect no-op reshapes by comparing their full input and output tensor shapes.
-                                inshape = full_shape(4, op.inputs[0].shape, 1)
-                                outshape = full_shape(4, op.outputs[0].shape, 1)
-                                # Using NHCWB16 format for a no-op reshape is only an option if subsequent
-                                # consumers do not also need to perform a reshape or if the OFM is going to
-                                # be processed by CPU operations. No-op reshape consumers with empty lists
-                                # (those that have no consumers, or null-consumers used as list terminators)
-                                # must use normal NHWC output.
-                                incompatible_consumers = [
-                                    (
-                                        not consumer.run_on_npu
-                                        or consumer.type == "Reshape"
-                                        or (consumer is last_op_in_subgraph)
-                                    )
-                                    for consumer in op.outputs[0].consumer_list
-                                    if consumer is not None
-                                ]
-                                if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
-                                    rewrites.append(op)
-                                else:
-                                    use_NHCWB16 = False
+                        if op.type == "ReduceSum" and output.dtype == DataType.int32:
+                            use_NHCWB16 = False
+                        elif op.type == "Reshape":
+                            # Detect no-op reshapes by comparing their full input and output tensor shapes.
+                            inshape = full_shape(4, op.inputs[0].shape, 1)
+                            outshape = full_shape(4, op.outputs[0].shape, 1)
+                            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
+                            # consumers do not also need to perform a reshape or if the OFM is going to
+                            # be processed by CPU operations. No-op reshape consumers with empty lists
+                            # (those that have no consumers, or null-consumers used as list terminators)
+                            # must use normal NHWC output.
+                            incompatible_consumers = [
+                                (
+                                    not consumer.run_on_npu
+                                    or consumer.type == "Reshape"
+                                    or (consumer is last_op_in_subgraph)
+                                )
+                                for consumer in op.outputs[0].consumer_list
+                                if consumer is not None
+                            ]
+                            if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
+                                rewrites.append(op)
                             else:
-                                use_NHCWB16 &= op.run_on_npu
+                                use_NHCWB16 = False
+                                use_fast_storage = False
+                        use_NHCWB16 &= op.run_on_npu
+                        use_fast_storage &= op.run_on_npu
 
-                        if use_NHCWB16:
-                            output.set_format(TensorFormat.NHCWB16, arch)
-                            for rewrite_op in rewrites:
-                                rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
+                    if use_fast_storage:
+                        fast_storage_tensor_rewrites[output] = rewrites
+                    if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
+                        output.set_format(TensorFormat.NHCWB16, arch)
+                        for rewrite_op in rewrites:
+                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
+            if self.feature_maps_not_in_fast_storage:
+                # Remember feature maps that can be moved to fast storage for later use
+                # in use_fast_storage_for_feature_maps
+                self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites
 
 
 def schedule_passes(nng, arch, options: SchedulerOptions):
@@ -1027,3 +1041,75 @@
 
         if options.verbose_schedule:
             sg.print_cascaded_passes()
+
+
+def _calc_tens_to_cps(sg, tensor_rewrites):
+    # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption.
+    # Returns dictionary tensor -> list of cascaded passes
+    # Note: if cascaded passes are A, B, C, D, and a tensor is output
+    # of A and input to D, then it also consumes SRAM in passes B and C.
+    if "tens_to_cps" in sg.scheduling_info:
+        return sg.scheduling_info["tens_to_cps"]
+    # Determine life-time of tensors
+    min_index = {}
+    max_index = {}
+    index = 0
+    cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu]
+    for cps in cps_list:
+        for tens in cps.inputs + cps.outputs:
+            if tens in tensor_rewrites:
+                min_index[tens] = min(index, min_index.get(tens, len(cps_list)))
+                max_index[tens] = index
+        index += 1
+    # Convert to affected cps-es
+    tens_to_cps = {}
+    for tens in min_index:
+        tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1]
+    sg.scheduling_info["tens_to_cps"] = tens_to_cps
+    return tens_to_cps
+
+
+def use_fast_storage_for_feature_maps(sg, sram_limit, arch):
+    # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes.
+    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
+    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
+    # Sort tensors first on life-time (smallest first), then on size (biggest first)
+    tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps])
+    for _, _, _, tens in tens_list:
+        cps_list = tens_to_cps[tens]
+        if len(cps_list) <= 1:
+            continue
+        sz = tens.storage_size()
+        fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list])
+        if fits_in_fast_storage:
+            tens.mem_area = arch.fast_storage_mem_area
+            tens.mem_type = MemType.Scratch_fast
+            tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
+            assert tens in tensor_rewrites
+            # Also rewrite reshapes
+            for rewrite_op in tensor_rewrites[tens]:
+                tens2 = rewrite_op.outputs[0]
+                tens2.mem_area = arch.fast_storage_mem_area
+                tens2.mem_type = MemType.Scratch_fast
+                tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None)
+            for cps in cps_list:
+                cps.sram_used += sz
+
+
+def undo_use_fast_storage(sg, arch):
+    # Undoes the effects of a previous call to use_fast_storage_for_feature_maps
+    tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {})
+    tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites)
+    mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap]
+    for tens, cps_list in tens_to_cps.items():
+        if tens.mem_type == MemType.Scratch_fast:
+            sz = tens.storage_size()
+            tens.mem_area = mem_area
+            tens.mem_type = MemType.Scratch
+            # Also undo reshapes
+            for rewrite_op in tensor_rewrites[tens]:
+                tens2 = rewrite_op.outputs[0]
+                tens2.mem_area = mem_area
+                tens2.mem_type = MemType.Scratch
+            for cps in cps_list:
+                cps.sram_used -= sz
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index eedbada..c0786bf 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -278,7 +278,7 @@
     def set_address_for_tens(cls, tens_id, mem_type, address):
         # Check previous address if there is one
         previous_address = cls.address_map[tens_id].get(mem_type)
-        if previous_address is not None:
+        if address is not None and previous_address is not None:
             assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."
 
         # Set tensor's address for memory type
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index d53babc..1efcd68 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -128,7 +128,10 @@
     show_minimum_possible_allocation=False,
     lr_graph=None,
     allocation_alignment=Tensor.AllocationQuantum,
+    max_size=None,
+    dry_test=False,
 ):
+    # Allocates addresses to tensors; returns False if the tensors could not be fitted within max_size
     ignore_subgraph_input_output_tensors = False
     lrs = live_range.extract_live_ranges_from_cascaded_passes(
         sg,
@@ -149,6 +152,12 @@
             total_sz = linear_allocate_live_ranges(lrs, allocation_alignment)
         else:
             assert 0
+        alloc_ok = max_size is None or total_sz <= max_size
+        if dry_test or not alloc_ok:
+            # Dry test or allocation failed; undo allocation
+            for lr in lrs.ranges.values():
+                lr.set_address(None)
+            return alloc_ok
 
         if sg.memory_used.get(mem_area, 0) == 0:
             sg.memory_used[mem_area] = total_sz
@@ -179,3 +188,4 @@
                 nng.bits_per_element[mem_area] = nng.total_size[mem_area] * 8 / nng.total_elements[mem_area]
             except ZeroDivisionError:
                 nng.bits_per_element[mem_area] = 0.0
+    return True
diff --git a/ethosu/vela/test/test_compiler_driver.py b/ethosu/vela/test/test_compiler_driver.py
new file mode 100644
index 0000000..56a90c4
--- /dev/null
+++ b/ethosu/vela/test/test_compiler_driver.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Unit tests for compiler driver
+from ethosu.vela.compiler_driver import next_sram_factor
+
+
+def test_next_sram_factor():
+    lower = 0.7
+    assert (1.0, False) == next_sram_factor([])
+    assert (None, False) == next_sram_factor([True])
+    assert (lower, True) == next_sram_factor([False])
+    assert ((1 + lower) / 2, True) == next_sram_factor([False, True])
+    assert (lower / 2, True) == next_sram_factor([False, False])
+    # Tests next_sram_factor for a range of simulated allocator efficiencies
+    for i in range(20):
+        allocator_factor = i / 20.0  # The simulated allocator efficiency
+        alloc_results = []
+        bisected_factor = 0  # The end result of the bisect search
+        while True:
+            factor, dry_test = next_sram_factor(alloc_results)
+            if factor is None:
+                break
+            alloc_result = factor < allocator_factor
+            if alloc_result and not dry_test:
+                bisected_factor = factor
+            alloc_results.append(alloc_result)
+            assert len(alloc_results) < 100
+        assert bisected_factor <= allocator_factor
+        assert abs(bisected_factor - allocator_factor) < 0.02