MLBEDSW-2306 Added more supported memory configurations

Additional supported memory configurations:
- Permanent_storage = DRAM
- Tensor arena in either DRAM or SRAM
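
An example excerpt from the system-configuration section of the Vela
config file selecting the new DRAM-based setup (key names and defaults
as read in architecture_features.py; the values are only illustrative):

    permanent_storage_mem_area=Dram
    feature_map_storage_mem_area=Dram
    fast_storage_mem_area=Dram
    sram_size_kb=256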

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I20beb7151e306bfdba540e7c0b2a7b478b4d94e1
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index fef2c40..e33c5d5 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -28,6 +28,7 @@
 from .operation import NpuBlockType
 from .supported_operators import SupportedOperators
 from .tensor import MemArea
+from .tensor import MemType
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
 
@@ -168,11 +169,6 @@
 
         is_yoda_system = "yoda-" in self.accelerator_config
 
-        if is_yoda_system:
-            self.sram_size = 256 * 1024
-        else:
-            self.sram_size = 200 * 1024 * 1024
-
         self.ncores = accel_config.cores
         self.ofm_ublock = accel_config.ofm_ublock
         self.ifm_ublock = accel_config.ifm_ublock
@@ -233,7 +229,8 @@
         self.default_weight_format = TensorFormat.WeightsCompressed
         self.default_feature_map_format = TensorFormat.NHWC
 
-        if permanent_storage != MemArea.OffChipFlash:
+        # Ignore permanent_storage = On/OffChipFlash for Yoda systems
+        if not is_yoda_system and permanent_storage != MemArea.OffChipFlash:
             self.permanent_storage_mem_area = permanent_storage
 
         self.tensor_storage_mem_area = {
@@ -243,10 +240,10 @@
             TensorPurpose.FeatureMap: self.feature_map_storage_mem_area,
         }
 
-        self.tensor_load_mem_area = dict(self.tensor_storage_mem_area)
-
-        if self.tensor_storage_mem_area[TensorPurpose.Weights] in (MemArea.OffChipFlash,):
-            self.tensor_load_mem_area[TensorPurpose.Weights] = MemArea.Sram
+        self.tensor_storage_mem_type = {
+            TensorPurpose.Weights: MemType.Permanent_NPU,
+            TensorPurpose.FeatureMap: MemType.Scratch,
+        }
 
         self.min_block_sizes = {
             NpuBlockType.Default: (dpu_min_height, dpu_min_width),
@@ -278,7 +275,7 @@
         self.max_sram_used_weight = 1000
 
         if is_yoda_system:
-            self.max_sram_used_weight = 0
+            self.max_sram_used_weight = 1000
 
         # Shared Buffer Block allocations
         self.shram_bank_size = 1024  # bytes
@@ -589,14 +586,21 @@
 
             self.fast_storage_mem_area = MemArea[self.__sys_config("fast_storage_mem_area", "Sram")]
             self.feature_map_storage_mem_area = MemArea[self.__sys_config("feature_map_storage_mem_area", "Sram")]
+
+            if self.fast_storage_mem_area != self.feature_map_storage_mem_area:
+                raise Exception(
+                    "Invalid memory configuration: fast_storage_mem_area must be the same as feature_map_storage_mem_area"
+                )
             self.permanent_storage_mem_area = MemArea[self.__sys_config("permanent_storage_mem_area", "OffChipFlash")]
-            if self.permanent_storage_mem_area not in set((MemArea.OnChipFlash, MemArea.OffChipFlash)):
+            if self.permanent_storage_mem_area not in set((MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Dram)):
                 raise Exception(
                     "Invalid permanent_storage_mem_area = "
                     + str(self.permanent_storage_mem_area)
-                    + " (must be 'OnChipFlash' or 'OffChipFlash'). To store the weights and other constant data in SRAM"
-                    " select 'OnChipFlash'"
+                    + " (must be 'OnChipFlash', 'OffChipFlash' or 'Dram')."
+                    " To store the weights and other constant data in SRAM on Ethos-U55, select 'OnChipFlash'"
                 )
+            self.sram_size = 1024 * int(self.__sys_config("sram_size_kb", "204800"))
+
         except Exception:
             print("Error: Reading System Configuration in vela configuration file, section {}".format(section_key))
             raise
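
Net effect of the changes in this file: tensor placement is now described
by a (MemArea, MemType) pair instead of by MemArea alone. A condensed
sketch of the default mapping set up above (mirrors the code, for
illustration only):

    # TensorPurpose -> (memory area holding the data, memory type classification)
    storage = {
        TensorPurpose.Weights: (self.permanent_storage_mem_area, MemType.Permanent_NPU),
        TensorPurpose.FeatureMap: (self.feature_map_storage_mem_area, MemType.Scratch),
    }
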
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index 9c345db..e495f1c 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -33,7 +33,7 @@
 from .nn_graph import PassPlacement
 from .nn_graph import TensorAllocator
 from .rewrite_graph import verify_graph_health
-from .tensor import MemArea
+from .tensor import MemType
 
 
 class CompilerOptions:
@@ -120,9 +120,6 @@
     # block config, and calc and pack the scales and biases
     weight_compressor.update_pass_weight_and_scale_tensors(nng, arch)
 
-    # Memory area for all non-constant tensors (Cpu and Npu)
-    non_const_mem_area = MemArea.Sram
-
     # LiveRanges for constant tensors for all Npu subgraphs
     permanent_storage = arch.permanent_storage_mem_area
     lr_graph_flash = live_range.LiveRangeGraph()
@@ -135,7 +132,11 @@
     for sg in nng.subgraphs:
         if sg.placement == PassPlacement.Npu:
             lr_graph_flash = live_range.extract_live_ranges_from_cascaded_passes(
-                sg, permanent_storage, ignore_subgraph_input_output_tensors=True, lr_graph=lr_graph_flash
+                sg,
+                permanent_storage,
+                MemType.Permanent_NPU,
+                ignore_subgraph_input_output_tensors=True,
+                lr_graph=lr_graph_flash,
             )
 
     if len(nng.subgraphs) > 1:
@@ -143,12 +144,12 @@
         # processed first during serialization into tensors
         first_npu_sg = nng.subgraphs[1]
         assert first_npu_sg.placement == PassPlacement.Npu
-        # Use the linear allocator for constant tensors
         tensor_allocation.allocate_tensors(
             nng,
             first_npu_sg,
             arch,
             permanent_storage,
+            set((MemType.Permanent_NPU,)),
             scheduler_options.use_ifm_ofm_overlap,
             TensorAllocator.LinearAlloc,
             options.verbose_allocation,
@@ -159,19 +160,36 @@
     # Allocate all non-constant tensors to the root, i.e. Cpu, subgraph. This step
     # will start at the root subgraph's input and traverse from top to bottom. When
     # it comes across an Npu-op it will extract live ranges for its corresponding
-    # Npu subgraph and add them to the root's live range graph. Finally, all of the
-    # non-constant tensors are allocated together
+    # Npu subgraph and add them to the root's live range graph.
+    # The non-constant tensors are stored either in arch.feature_map_storage_mem_area or
+    # arch.fast_storage_mem_area.
+    # When these memory areas are the same, all non-constant tensors are allocated together.
+    # Otherwise they are allocated separately.
+
     root_sg = nng.get_root_subgraph()
-    tensor_allocation.allocate_tensors(
-        nng,
-        root_sg,
-        arch,
-        non_const_mem_area,
-        scheduler_options.use_ifm_ofm_overlap,
-        options.tensor_allocator,
-        options.verbose_allocation,
-        options.show_minimum_possible_allocation,
-    )
+
+    alloc_list = []
+    if arch.feature_map_storage_mem_area == arch.fast_storage_mem_area:
+        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))
+        alloc_list.append(mem_alloc_scratch)
+    else:
+        mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,)))
+        mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,)))
+        alloc_list.append(mem_alloc_scratch)
+        alloc_list.append(mem_alloc_scratch_fast)
+
+    for mem_area, mem_type_set in alloc_list:
+        tensor_allocation.allocate_tensors(
+            nng,
+            root_sg,
+            arch,
+            mem_area,
+            mem_type_set,
+            scheduler_options.use_ifm_ofm_overlap,
+            options.tensor_allocator,
+            options.verbose_allocation,
+            options.show_minimum_possible_allocation,
+        )
 
     # Generate command streams and serialise Npu-ops into tensors
     for sg in nng.subgraphs:
@@ -194,6 +212,7 @@
         root_sg,
         arch,
         permanent_storage,
+        set((MemType.Permanent_CPU,)),
         scheduler_options.use_ifm_ofm_overlap,
         TensorAllocator.LinearAlloc,
         options.verbose_allocation,
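
For a configuration where feature maps stay in DRAM while fast storage is
SRAM, the loop above performs two separate allocations; schematically (a
sketch with assumed memory areas, not code from this patch):

    alloc_list = [
        (MemArea.Dram, set((MemType.Scratch,))),       # bulk feature maps in the tensor arena
        (MemArea.Sram, set((MemType.Scratch_fast,))),  # working set that is DMA'd into fast storage
    ]
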
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
index 7049a05..5c05fc8 100644
--- a/ethosu/vela/insert_dma.py
+++ b/ethosu/vela/insert_dma.py
@@ -19,6 +19,7 @@
 from .operation import NpuBlockType
 from .operation import Operation
 from .tensor import MemArea
+from .tensor import MemType
 from .tensor import TensorPurpose
 
 
@@ -30,29 +31,34 @@
         return op  # Already rewritten
     for idx, tens in enumerate(op.inputs):
 
-        if tens.mem_area in (MemArea.Dram, MemArea.OffChipFlash) and tens.mem_area != arch.fast_storage_mem_area:
-            if tens.purpose == TensorPurpose.Weights or (
-                tens.purpose == TensorPurpose.FeatureMap and op.type in binary_elementwise_op and tens.shape != []
+        if tens.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
+            # Tensor is in permanent storage.
+            # Moving the data is only worthwhile when permanent storage differs from fast storage.
+            if tens.mem_area in (MemArea.Dram, MemArea.OffChipFlash) and (
+                arch.permanent_storage_mem_area != arch.fast_storage_mem_area
             ):
-                only_vector_product_consumers = True
-                for oper in tens.consumers():
-                    if oper is None or oper.attrs.get("npu_block_type") != NpuBlockType.VectorProduct:
-                        only_vector_product_consumers = False
-                        break
+                if tens.purpose == TensorPurpose.Weights or (
+                    tens.purpose == TensorPurpose.FeatureMap and op.type in binary_elementwise_op and tens.shape != []
+                ):
+                    only_vector_product_consumers = True
+                    for oper in tens.consumers():
+                        if oper is None or oper.attrs.get("npu_block_type") != NpuBlockType.VectorProduct:
+                            only_vector_product_consumers = False
+                            break
 
-                # Tensor products has no need for DMA, tensors are only read once and can be in flash.
-                # Other operations re-reads tensors, this is better done from SRAM.
-                if not only_vector_product_consumers:
-                    # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size.
-                    new_tens = tens.clone_into_fast_storage(arch)
-                    dma_cmd = Operation("DMA", tens.ops[0].name + "_dma")
-                    dma_cmd.inputs = [tens]
-                    dma_cmd.outputs = [new_tens]
-                    dma_cmd.attrs["source"] = tens.mem_area
-                    dma_cmd.attrs["destination"] = new_tens.mem_area
-                    dma_cmd.run_on_npu = True
-                    new_tens.ops = [dma_cmd]
-                    op.inputs[idx] = new_tens
+                    # Tensor products have no need for DMA; tensors are only read once and can stay in flash.
+                    # Other operations re-read tensors; this is better done from SRAM.
+                    if not only_vector_product_consumers:
+                        # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size.
+                        new_tens = tens.clone_into_fast_storage(arch)
+                        dma_cmd = Operation("DMA", tens.ops[0].name + "_dma")
+                        dma_cmd.inputs = [tens]
+                        dma_cmd.outputs = [new_tens]
+                        dma_cmd.attrs["source"] = tens.mem_area
+                        dma_cmd.attrs["destination"] = new_tens.mem_area
+                        dma_cmd.run_on_npu = True
+                        new_tens.ops = [dma_cmd]
+                        op.inputs[idx] = new_tens
     return op
 
 
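
The effect of the rewrite above on a single weight tensor, schematically (a
sketch of the graph fragment before and after, not output from this patch):

    # before:  weights (Dram, Permanent_NPU) ------------------------------------------> Conv2D
    # after:   weights (Dram, Permanent_NPU) -> DMA -> weights_fast_storage (Sram, Scratch_fast) -> Conv2D
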
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 2a35a11..8fe3d57 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -18,7 +18,7 @@
 # Can work with either a pass packed subgraph or a scheduled subgraph.
 from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_cascaded_pass
 from .nn_graph import PassPlacement
-from .tensor import MemArea
+from .tensor import MemType
 from .tensor import Tensor
 
 
@@ -220,6 +220,7 @@
 def extract_live_ranges_from_cascaded_passes(
     sg,
     target_mem_area,
+    target_mem_type_set,
     mark_output_tensors_overlapping_with_input_tensors=False,
     use_ifm_ofm_overlap=True,
     ignore_subgraph_input_output_tensors=False,
@@ -236,8 +237,8 @@
         lr_graph.ignore_tensors.update(sg.input_tensors)
         lr_graph.ignore_tensors.update(sg.output_tensors)
 
-    def tensor_should_be_ignored(tens, target_mem_area):
-        if tens.mem_area != target_mem_area:
+    def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+        if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
             return True
         if tens in lr_graph.ignore_tensors:
             return True
@@ -247,9 +248,24 @@
             return True
         return False
 
+    def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set):
+        for ps in sg.passes:
+            if ps.placement == PassPlacement.MemoryOnly:
+                # For memory-only passes, e.g. Reshape, add the input and output tensor to the same LiveRange
+                input_tensor = ps.inputs[0]
+                output_tensor = ps.outputs[0]
+                # If the input or output tensor is tied to a Cpu tensor, i.e. a subgraph input
+                # or output, fuse the live-range with the Cpu tensors' live-range instead.
+                input_tensor = input_tensor.cpu_tensor if input_tensor.cpu_tensor is not None else input_tensor
+                output_tensor = output_tensor.cpu_tensor if output_tensor.cpu_tensor is not None else output_tensor
+                if not tensor_should_be_ignored(input_tensor, target_mem_area, target_mem_type_set) and not (
+                    tensor_should_be_ignored(output_tensor, target_mem_area, target_mem_type_set)
+                ):
+                    lr_graph.fuse_ranges(input_tensor, output_tensor)
+
     # Merge only memory operations in the NPU subgraphs
     if sg.placement == PassPlacement.Npu:
-        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area)
+        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set)
 
     for cps in sg.cascaded_passes:
         cps.time = lr_graph.current_time
@@ -259,19 +275,21 @@
         is_element_wise = cps.is_element_wise
 
         for tens in cps.inputs:
-            if tensor_should_be_ignored(tens, target_mem_area):
+            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens)
             rng.mark_usage(time_for_pass)
 
         cps_primary_op = cps.passes[0].primary_op
-        if cps_primary_op and cps_primary_op.type == "NpuOp" and target_mem_area in set((MemArea.Sram, MemArea.Dram)):
+
+        if cps_primary_op and cps_primary_op.type == "NpuOp" and MemType.Permanent_CPU not in target_mem_type_set:
             # If the primary-op is an NpuOp that means this is where an Npu subgraph
             # is called. Go into said subgraph and extract live ranges before continuing.
             npu_sg = cps_primary_op.attrs["subgraph"]
             lr_graph = extract_live_ranges_from_cascaded_passes(
                 npu_sg,
                 target_mem_area,
+                target_mem_type_set,
                 mark_output_tensors_overlapping_with_input_tensors,
                 use_ifm_ofm_overlap,
                 False,
@@ -282,13 +300,13 @@
             cps.time = time_for_pass
 
         for tens in cps.intermediates:
-            if tensor_should_be_ignored(tens, target_mem_area):
+            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens)
             rng.mark_usage(time_for_pass)
 
         for tens in cps.outputs:
-            if tensor_should_be_ignored(tens, target_mem_area):
+            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens)
             output_time = time_for_pass
@@ -303,8 +321,8 @@
             if (
                 ifm_tensor is not None
                 and ofm_tensor is not None
-                and not tensor_should_be_ignored(ifm_tensor, target_mem_area)
-                and not tensor_should_be_ignored(ofm_tensor, target_mem_area)
+                and not tensor_should_be_ignored(ifm_tensor, target_mem_area, target_mem_type_set)
+                and not tensor_should_be_ignored(ofm_tensor, target_mem_area, target_mem_type_set)
             ):
                 lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
                     cps
@@ -318,7 +336,7 @@
         end_time = max(end_time, rng.end_time)
 
     for tens in sg.output_tensors:
-        if tensor_should_be_ignored(tens, target_mem_area):
+        if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
             continue
         rng = lr_graph.get_or_create_range(tens)
         rng.mark_usage(end_time)
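
Callers now select live ranges by both memory area and memory type. For
example, extracting only the fast-storage working set could look like this
(hypothetical call, argument values assumed):

    lr_graph = live_range.extract_live_ranges_from_cascaded_passes(
        sg,
        MemArea.Sram,
        set((MemType.Scratch_fast,)),
        ignore_subgraph_input_output_tensors=True,
        lr_graph=lr_graph,
    )
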
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index c4f2bae..705f839 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -18,6 +18,7 @@
 from . import rewrite_graph
 from . import weight_compressor
 from .errors import OperatorError
+from .tensor import MemType
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
 from .tflite_mapping import custom_prefix
@@ -254,11 +255,13 @@
         else:
             assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
         tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]
+        tens.mem_type = arch.tensor_storage_mem_type[tens.purpose]
 
         if len(tens.ops) == 1 and tens.ops[0].type == "Const":
             tens.mem_area = (
                 arch.permanent_storage_mem_area
             )  # special case constants, as they must be in permanent storage
+            tens.mem_type = MemType.Permanent_NPU
 
     def rewrite_mark_tensor_purpose(op, arch):
         # find disconnected outputs and mark as parameters
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index ea35c08..247e6cc 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -137,6 +137,7 @@
         self.flash_tensor = None
 
         self.memory_used = {}
+        self.memory_used_per_type = {}
 
     def __str__(self):
         return "<nng.Subgraph '%s',  n_passes=%d, n_cascaded_passes=%d>" % (
@@ -349,9 +350,15 @@
         for idx, op in enumerate(all_ops):
             print(idx, op.type, op.name)
             for idx, tens in enumerate(op.inputs):
-                print("    Input  %02d %20s %20s %s" % (idx, tens.purpose.name, tens.mem_area.name, tens))
+                print(
+                    "    Input  %02d %20s %20s %20s %s"
+                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
+                )
             for idx, tens in enumerate(op.outputs):
-                print("    Output %02d %20s %20s %s" % (idx, tens.purpose.name, tens.mem_area.name, tens))
+                print(
+                    "    Output %02d %20s %20s %20s %s"
+                    % (idx, tens.purpose.name, tens.mem_area.name, tens.mem_type.name, tens)
+                )
             print()
 
     def print_graph_with_tensor_quantization(self):
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index 18d38f3..bd13a3e 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -24,14 +24,16 @@
 from .nn_graph import PassPlacement
 from .operation import Operation
 from .tensor import MemArea
+from .tensor import MemType
 from .tensor import Tensor
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
 
 
-def make_memory_tensor(name, mem_area, sz, want_values, arch):
+def make_memory_tensor(name, mem_area, mem_type, sz, want_values, arch):
     tens = Tensor([sz], DataType.uint8, name)
     tens.mem_area = mem_area
+    tens.mem_type = mem_type
     tens.purpose = TensorPurpose.FeatureMap
     tens.set_format(TensorFormat.NHWC, arch)
     if want_values:
@@ -58,7 +60,7 @@
         return scratch_tens, flash_tens
 
     flash_area = arch.permanent_storage_mem_area
-    scratch_area = MemArea.Sram
+    scratch_area = arch.feature_map_storage_mem_area
 
     flash_size = sg.memory_used.get(flash_area, 0)
     scratch_size = sg.memory_used.get(scratch_area, 0)
@@ -85,9 +87,13 @@
 
     if flash_tens == scratch_tens is None:
         # First Npu subgraph, create scratch and flash tensors
-        sg.scratch_tensor = make_memory_tensor(sg.name + "_scratch", scratch_area, scratch_size, False, arch)
+        sg.scratch_tensor = make_memory_tensor(
+            sg.name + "_scratch", scratch_area, MemType.Scratch, scratch_size, False, arch
+        )
         sg.scratch_tensor.purpose = TensorPurpose.Scratch
-        sg.flash_tensor = make_memory_tensor(sg.name + "_flash", flash_area, flash_size, True, arch)
+        sg.flash_tensor = make_memory_tensor(
+            sg.name + "_flash", flash_area, MemType.Permanent_CPU, flash_size, True, arch
+        )
     else:
         sg.scratch_tensor = scratch_tens
         sg.scratch_tensor.shape[0] += scratch_size
@@ -108,13 +114,15 @@
 
                     copy_compressed_values_to_memory_tensor(sg.flash_tensor, ps.scale_tensor)
 
-                if ps.ifm_tensor is not None and ps.ifm_tensor.mem_area != MemArea.Sram:
+                if ps.ifm_tensor is not None and ps.ifm_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
                     copy_ifm_values_to_memory_tensor(sg.flash_tensor, ps.ifm_tensor)
-                if ps.ifm2_tensor is not None and ps.ifm2_tensor.mem_area != MemArea.Sram:
+                if ps.ifm2_tensor is not None and (
+                    ps.ifm2_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast)
+                ):
                     copy_ifm_values_to_memory_tensor(sg.flash_tensor, ps.ifm2_tensor)
 
     sg.command_stream_tensor = make_memory_tensor(
-        sg.name + "_command_stream", flash_area, command_stream_size_bytes, True, arch
+        sg.name + "_command_stream", flash_area, MemType.Permanent_CPU, command_stream_size_bytes, True, arch
     )
     sg.command_stream_tensor.values = np.frombuffer(payload_bytes, dtype=np.uint8)
 
@@ -156,4 +164,5 @@
                         prev_cps.sram_used += sz
 
                     if callee.scratch_tensor is not None:
-                        cps.sram_used += callee.scratch_tensor.storage_size()
+                        if callee.scratch_tensor.mem_area == MemArea.Sram:
+                            cps.sram_used += callee.scratch_tensor.storage_size()
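
Since scratch_area now follows the system configuration, a build with the
tensor arena in DRAM and constants in flash creates its memory tensors
roughly like this (sketch; sizes and areas are assumed):

    scratch = make_memory_tensor(sg.name + "_scratch", MemArea.Dram, MemType.Scratch, scratch_size, False, arch)
    flash = make_memory_tensor(sg.name + "_flash", MemArea.OffChipFlash, MemType.Permanent_CPU, flash_size, True, arch)
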
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index c46016d..9dd290a 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -50,7 +50,7 @@
 from .numeric_util import round_up_to_int
 from .operation import NpuBlockType
 from .shared_buffer_allocation import SharedBufferAllocation
-from .tensor import MemArea
+from .tensor import MemType
 from .tensor import TensorBlockTraversal
 from .tensor import TensorFormat
 
@@ -79,8 +79,9 @@
 
 
 class BasePointerIndex(IntEnum):
-    ReadOnly = 0  # base address slot index for weights and scaling
-    Scratch = 1  # base address slot index for scratch memory area
+    WeightTensor = 0  # base address index for the Weight tensor
+    ScratchTensor = 1  # base address index for the Scratch tensor in the TensorArena
+    ScratchFastTensor = 2  # base address index for the Scratch_fast tensor
 
 
 # TODO: Replace with definitions from ethos_u55_regs
@@ -322,12 +323,16 @@
 def generate_register_command_stream(nng, sg, arch, verbose=False):
     emit = CommandStreamEmitter()
 
-    base_ptr_idx_map = {
-        MemArea.Sram: BasePointerIndex.Scratch,
-        MemArea.OnChipFlash: BasePointerIndex.ReadOnly,
-        MemArea.OffChipFlash: BasePointerIndex.ReadOnly,
-        MemArea.Dram: BasePointerIndex.ReadOnly,
-    }
+    base_ptr_idx_map = {
+        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
+        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
+        MemType.Scratch: BasePointerIndex.ScratchTensor,
+    }
+
+    if arch.feature_map_storage_mem_area == arch.fast_storage_mem_area:
+        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor
+    else:
+        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
 
     # Maps an AccumulatorType enum to the corresponding acc_format value
     acc_format_map = {
@@ -377,8 +386,8 @@
                 param = min(param, 0xFFFF)  # Clamp to allowable wait amount
 
         if relative_dep[CommandType.DMA] is not None:
-            param = relative_dep[CommandType.DMA][0]
-            param = min(param, 0xF)  # Clamp to allowable wait amount
+            # TODO: This can be optimized for Yoda
+            param = 0
             emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, param, absolute_dep[CommandType.DMA][0])
 
     for cmd in cmd_stream:
@@ -394,10 +403,10 @@
             else:
                 sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
 
-            # TODO: Yoda support needs to use feature_maps_not_in_fast_storage and force_outputs_to_fast_storage
-            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_SRC_REGION, base_ptr_idx_map[cmd.in_tensor.mem_area])
+            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_SRC_REGION, base_ptr_idx_map[cmd.in_tensor.mem_type])
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_SRC, src_addr)
-            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_area])
+            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_type])
+
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_DST, dst_addr)
             emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, sz)
             dma_channel = 0
@@ -682,10 +691,7 @@
                 stream_index = cmd.weight_tensor.compressed_stream_index_from_coord(cmd.weight_box.start_coord)
                 weight_addr = cmd.weight_tensor.address_for_coordinate(cmd.weight_box.start_coord)
                 weight_len = cmd.weight_tensor.size_of_compressed_stream(stream_index)
-                # Select weight/scale region depending on where permanent storage was defined
-                weight_region = base_ptr_idx_map[cmd.weight_tensor.mem_area]
-                if arch.permanent_storage_mem_area == MemArea.Sram:
-                    weight_region = BasePointerIndex.ReadOnly
+                weight_region = base_ptr_idx_map[cmd.weight_tensor.mem_type]
                 emit.cmd0_with_param(cmd0.NPU_SET_WEIGHT_REGION, weight_region)
                 emit.cmd1_with_offset(cmd1.NPU_SET_WEIGHT_BASE, weight_addr)
                 emit.cmd1_with_offset(cmd1.NPU_SET_WEIGHT_LENGTH, weight_len)
@@ -699,9 +705,7 @@
                         cmd.scale_tensor.address_for_coordinate(cmd.weight_box.end_coord[-1:], True) - scale_addr
                     )
                     # Emit base address for NPU to access scale & bias data
-                    scale_region = base_ptr_idx_map[cmd.scale_tensor.mem_area]
-                    if arch.permanent_storage_mem_area == MemArea.Sram:
-                        scale_region = BasePointerIndex.ReadOnly
+                    scale_region = base_ptr_idx_map[cmd.scale_tensor.mem_type]
                     emit.cmd0_with_param(cmd0.NPU_SET_SCALE_REGION, scale_region)
                     emit.cmd1_with_offset(cmd1.NPU_SET_SCALE_BASE, scale_addr)
                     emit.cmd1_with_offset(cmd1.NPU_SET_SCALE_LENGTH, round_up(scale_len, 16))
@@ -850,10 +854,7 @@
                     else:
                         assert False
 
-                if tens.mem_area == MemArea.Sram:
-                    emit.cmd0_with_param(region_op, BasePointerIndex.Scratch)
-                else:
-                    emit.cmd0_with_param(region_op, BasePointerIndex.ReadOnly)
+                emit.cmd0_with_param(region_op, base_ptr_idx_map[tens.mem_type])
 
                 for idx, addr in enumerate(addresses):
                     if addr is None:
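
The three base pointer indices name the regions that the driver has to
program before running the command stream. A sketch of the implied
driver-side setup (the variable names here are illustrative):

    region_base_addresses = [0] * 3
    region_base_addresses[BasePointerIndex.WeightTensor] = read_only_base      # weights, scales, command stream
    region_base_addresses[BasePointerIndex.ScratchTensor] = tensor_arena_base  # Scratch, plus Scratch_fast when co-allocated
    region_base_addresses[BasePointerIndex.ScratchFastTensor] = fast_scratch_base  # Scratch_fast when fast storage is separate
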
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 0b59431..be104b8 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -38,6 +38,7 @@
 from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
 from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
 from .tensor import MemArea
+from .tensor import MemType
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
 from .tensor import TensorSubPurpose
@@ -833,6 +834,7 @@
             for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
                 if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
                     tens.mem_area = self.arch.fast_storage_mem_area
+                    tens.mem_type = MemType.Scratch_fast
                     tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
                 else:
                     assert 0, "unknown rewrite_op " + str(rewrite_op)
diff --git a/ethosu/vela/stats_writer.py b/ethosu/vela/stats_writer.py
index 9bbb9db..c90d987 100644
--- a/ethosu/vela/stats_writer.py
+++ b/ethosu/vela/stats_writer.py
@@ -201,7 +201,7 @@
                             for k in indices[2]:
                                 res += round_up_to_int(ps.bandwidths[i, j, k])
                         stats.append(res)
-                    stats += [ps.sram_used]
+                    stats += [getattr(ps, "sram_used", 0)]
 
                     writer.writerow(stats)
 
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 42d9526..3990164 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -26,6 +26,27 @@
 from .range_set import MemoryRangeSet
 
 
+class MemType(enum.IntFlag):
+    Unknown = 0
+    Permanent_NPU = 1
+    Permanent_CPU = 2
+    Scratch = 3
+    Scratch_fast = 4
+    Size = Scratch_fast + 1
+
+    def display_name(self):
+        return ("Unknown", "Permanent_NPU", "Permanent_CPU", "Scratch", "Scratch_fast", "Size")[self.value]
+
+    def identifier_name(self):
+        return ("unknown", "permanent_npu", "permanent_cpu", "scratch", "scratch_fast", "size")[self.value]
+
+    def all():
+        return (MemType.Permanent_NPU, MemType.Permanent_CPU, MemType.Scratch, MemType.Scratch_fast)
+
+    def __str__(self):
+        return self.name
+
+
 class MemArea(enum.IntFlag):
     Unknown = 0
     Sram = 1
@@ -209,6 +230,7 @@
         "quant_values",
         "compressed_values",
         "mem_area",
+        "mem_type",
         "format",
         "purpose",
         "sub_purpose",
@@ -252,6 +274,7 @@
         self.quant_values = None
         self.compressed_values = None
         self.mem_area = MemArea.Unknown
+        self.mem_type = MemType.Unknown
         self.format = TensorFormat.Unknown
         self.purpose = TensorPurpose.Unknown
         self.sub_purpose = TensorSubPurpose.Standard
@@ -291,6 +314,7 @@
         res.values = self.values
         res.quant_values = self.quant_values
         res.mem_area = self.mem_area
+        res.mem_type = self.mem_type
         res.format = self.format
         res.purpose = self.purpose
         res.sub_purpose = self.sub_purpose
@@ -312,6 +336,7 @@
     def clone_into_fast_storage(self, arch):
         res = self.clone(suffix="_fast_storage")
         res.mem_area = arch.fast_storage_mem_area
+        res.mem_type = MemType.Scratch_fast
         return res
 
     def copy_compressed_weight_info(self, src_tens):
@@ -641,6 +666,9 @@
         assert address_offset <= self.storage_size()
         return address_offset
 
+    def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area):
+        return self.mem_area == scratch_tensor_mem_area and self.mem_type in (MemType.Scratch, MemType.Scratch_fast)
+
     def __str__(self):
         return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
 
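
MemType classifies a tensor orthogonally to MemArea, which is what allows
the tensor arena itself to live in either DRAM or SRAM. A small usage
sketch (hypothetical values; import paths assumed from the repo layout):

    from ethosu.vela.data_type import DataType
    from ethosu.vela.tensor import MemArea, MemType, Tensor

    t = Tensor([1000], DataType.uint8, "fm")
    t.mem_area = MemArea.Dram     # the arena containing this tensor is placed in DRAM
    t.mem_type = MemType.Scratch  # ...but it is still scratch, not permanent storage
    assert t.is_allocated_in_tensor_arena(MemArea.Dram)
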
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index e3952df..f29296d 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -25,6 +25,7 @@
 from .greedy_allocation import allocate_live_ranges as greedy_allocate_live_ranges
 from .nn_graph import TensorAllocator
 from .tensor import MemArea
+from .tensor import MemType
 
 
 def linear_allocate_live_ranges(live_ranges, alloc_granularity=16):
@@ -66,12 +67,13 @@
             ps.sram_used = sram_used
 
 
-def print_allocation(lrs, mem_area, sg, verbose_allocation, show_minimum_possible_allocation):
+def print_allocation(lrs, mem_area, mem_type_set, sg, verbose_allocation, show_minimum_possible_allocation):
     if verbose_allocation:
-        if mem_area == MemArea.Sram:
-            print("allocation for", mem_area, "- non-constant tensors in Cpu and Npu subgraphs")
-        else:
+        if mem_type_set == set((MemType.Permanent_NPU,)) or mem_type_set == set((MemType.Permanent_CPU,)):
             print("allocation for", mem_area, "- constant tensors in", sg.placement.name, "subgraph(s)")
+        else:
+            print("allocation for", mem_area, "- non-constant tensors in Cpu and Npu subgraphs")
+
         for start_time, start, end, name, end_time in sorted(
             (
                 lr.start_time,
@@ -99,6 +101,7 @@
     sg,
     arch,
     mem_area,
+    mem_type_set,
     use_ifm_ofm_overlap=True,
     tensor_allocator=TensorAllocator.Greedy,
     verbose_allocation=False,
@@ -109,6 +112,7 @@
     lrs = live_range.extract_live_ranges_from_cascaded_passes(
         sg,
         mem_area,
+        mem_type_set,
         mark_output_tensors_overlapping_with_input_tensors=False,
         use_ifm_ofm_overlap=use_ifm_ofm_overlap,
         ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
@@ -120,16 +124,20 @@
         if tens_alloc == TensorAllocator.Greedy:
             total_sz = greedy_allocate_live_ranges(sg, arch, lrs, mem_area, verbose_allocation)
         elif tens_alloc == TensorAllocator.LinearAlloc:
-            total_sz = linear_allocate_live_ranges(lrs)
+            total_sz = linear_allocate_live_ranges(lrs, 16)
         else:
             assert 0
 
-        sg.memory_used[mem_area] = total_sz
+        sg.memory_used[mem_area] = sg.memory_used.get(mem_area, 0) + total_sz
+
+        # Keep track of how much is used for scratch or permanent storage for the NPU
+        for mem_type in mem_type_set:
+            sg.memory_used_per_type[mem_type] = sg.memory_used_per_type.get(mem_type, 0) + total_sz
 
         nng.total_size[mem_area] = nng.total_size.get(mem_area, 0) + sum(tens.storage_size() for tens in lrs.ranges)
         nng.total_elements[mem_area] = nng.total_elements.get(mem_area, 0) + sum(tens.elements() for tens in lrs.ranges)
 
-        print_allocation(lrs, mem_area, sg, verbose_allocation, show_minimum_possible_allocation)
+        print_allocation(lrs, mem_area, mem_type_set, sg, verbose_allocation, show_minimum_possible_allocation)
 
         if mem_area == MemArea.Sram:
             # Mark Sram usage for all subgraphs
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 8db3e5b..7e805e3 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -22,7 +22,7 @@
 from flatbuffers.builder import UOffsetTFlags
 
 from .nn_graph import PassPlacement
-from .tensor import MemArea
+from .tensor import MemType
 from .tensor import TensorPurpose
 from .tflite import Buffer
 from .tflite import Metadata
@@ -74,6 +74,7 @@
         self.nng = nng
 
         self.scratch_buf_id = 0  # Always assign scratch to buffer 0
+        self.scratch_fast_buf_id = 1  # Always assign scratch_fast to buffer 1
         self.buffer_offsets_map = {}
         self.buffers_to_write = []  # have an empty array there
 
@@ -140,11 +141,16 @@
             scratch_tensor_mem_area = None  # all tensors are initialised to MemArea.Unknown
 
         buffer_map = {}
+
         buf_idx = 1
 
         for tens in tensors:
-            if tens.mem_area == scratch_tensor_mem_area:
+            # Set buffer ids depending on allocation
+            if tens.is_allocated_in_tensor_arena(scratch_tensor_mem_area):
                 buffer_map[tens] = self.scratch_buf_id
+            elif tens.mem_type == MemType.Scratch_fast:
+                # For Scratch_fast when not co-allocated with scratch in the TensorArena:
+                buffer_map[tens] = self.scratch_fast_buf_id
             else:
                 buffer_map[tens] = buf_idx
                 buf_idx += 1
@@ -229,11 +235,9 @@
 
         if tens.purpose == TensorPurpose.Scratch:
             tens_shape = [0]
-            self.buffers_to_write[self.scratch_buf_id] = values.flatten().view(np.uint8)
 
         buf_id = self.buffer_map[tens]
-        if buf_id != self.scratch_buf_id:
-            self.buffers_to_write[buf_id] = values.flatten().view(np.uint8)
+        self.buffers_to_write[buf_id] = values.flatten().view(np.uint8)
 
         shape = self.write_int_vector(tens_shape)
 
@@ -396,7 +400,8 @@
 
         # Ensure that the order of the offsets match the order of the tensors
         for tens, idx in self.tensor_map.items():
-            if tens.mem_area == MemArea.Sram:
+            # Set offsets for tensors allocated in the TensorArena or in the scratch_fast area
+            if tens.mem_type in set((MemType.Scratch, MemType.Scratch_fast)):
                 offsets[idx] = np.int32(tens.address)
 
         metadata_buffer = np.array([version, subgraph_idx, nbr_tensors] + offsets)
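
For a co-allocated configuration, the buffer assignment above behaves like
this sketch (the tensor names are assumed):

    buffer_map = {
        scratch_tensor: 0,         # tensor arena placeholder, shared by all arena tensors
        weights_tensor: 1,         # first constant buffer, from buf_idx
        command_stream_tensor: 2,  # and so on
    }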