MLBEDSW-6263: Use separate tensors for double buffering

Use a separate tensor for each individual weight buffer
when weights are double buffered.

Each weight buffer tensor gets its own live range.

This patch is a clone of a previously reverted patch, but with some
additional bug fixes applied.

Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I868c70d15821eb9f1399186f2da6e7345f6ee343
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 4703583..e7105e2 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -146,10 +146,8 @@
             # Keep track of which Ops are in the proposed cascade as well as the best cascade so far
             ops_in_cascade = [op]
             ops_in_best_cascade = [op]
-            # Get the size of the weight buffer
-            weight_buffer = 0
-            if ref_cost[op].buffered_weight_tensor:
-                weight_buffer = ref_cost[op].buffered_weight_tensor.storage_size()
+            # Get the size of the weight buffer(s)
+            weight_buffer = sum(tens.storage_size() for tens in ref_cost[op].buffered_weight_tensors)
 
             # The first IFM needs to be stored in full
             cascade_ifm_size = op.ifm_size_in_bytes() if not self.spilling else 0
@@ -192,10 +190,8 @@
                 op_full_ofm = current_op.ofm_size_in_bytes()
                 _, op_ifm_buffer = buffers.get_buffer(producer, current_op, ref_cost)
 
-                # Get the size of the weight buffer
-                op_weight_buffer = 0
-                if ref_cost[current_op].buffered_weight_tensor:
-                    op_weight_buffer = ref_cost[current_op].buffered_weight_tensor.storage_size()
+                # Get the size of the weight buffer(s)
+                op_weight_buffer = sum(tens.storage_size() for tens in ref_cost[current_op].buffered_weight_tensors)
 
                 # Calculate the uncascaded memory requirement for current Op
                 uncascaded_sram_usage = op_full_ifm + op_full_ofm + self.non_local_mem_usage.get(current_op, 0)
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 136f5a9..81c0d5b 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -204,9 +204,12 @@
                 if op_info.npu_weights_tensor:
                     weight_box = Box([0, 0, 0, start_channel], [1, 1, 1, end_channel])
 
-                    if op_info.buffered_weight_tensor and is_first_h_stripe:
-                        yield from dma_if_necessary(sched_op.parent_ps, weight_box, op_info.buffered_weight_tensor)
-                        weight_tensor = op_info.buffered_weight_tensor
+                    if op_info.buffered_weight_tensors and is_first_h_stripe:
+                        idx = depth_idx % len(op_info.buffered_weight_tensors)
+                        yield from dma_if_necessary(
+                            sched_op.parent_ps, weight_box, op_info.buffered_weight_tensors[idx]
+                        )
+                        weight_tensor = op_info.buffered_weight_tensors[idx]
                 else:
                     weight_box = None
 
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 3a78d6f..e6bfc1c 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -68,7 +68,6 @@
 from .tensor import Tensor
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
-from .tensor import TensorSubPurpose
 from .weight_compressor import NpuWeightTensor
 from .weight_compressor import WeightKey
 
@@ -202,9 +201,15 @@
     return mem_limits
 
 
-def get_double_buffer_offset(arch: ArchitectureFeatures, range_index: int, core: int) -> int:
-    """Returns 0 if the first half of a double buffer should be used, 1 if the second half should be used"""
-    return ((range_index - core) // arch.ncores) % 2
+def get_upscale(op: Operation) -> NpuResamplingMode:
+    upscale = NpuResamplingMode.NONE
+    if op.type == Op.ResizeBilinear:
+        # perform nearest neighbor upscale
+        upscale = NpuResamplingMode.NEAREST
+    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
+        # perform insert zero upscale
+        upscale = NpuResamplingMode.TRANSPOSE
+    return upscale
 
 
 def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
@@ -314,20 +319,13 @@
         key = WeightKey(core, weight_box.start_coord[-1])
         if key in w_tensor_src.encoded_ranges:
             weight_range = w_tensor_src.encoded_ranges[key]
-            if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
-                assert weight_tensor != w_tensor_src
-                # Double buffered inside weight_tensor
-                address = weight_tensor.address + core_offset
-                address += get_double_buffer_offset(arch, weight_range.index, core) * w_tensor_src.max_range_bytes
-                core_offset += round_up(weight_range.total_bytes, 16)
+            if weight_tensor == w_tensor_src:
+                # Straight from source tensor
+                address = weight_tensor.address + weight_range.offset
             else:
-                if weight_tensor == w_tensor_src:
-                    # Straight from source tensor
-                    address = weight_tensor.address + weight_range.offset
-                else:
-                    # Single buffered inside weight tensor
-                    address = weight_tensor.address + core_offset
-                    core_offset += round_up(weight_range.total_bytes, 16)
+                # Weight buffered tensor
+                address = weight_tensor.address + core_offset
+                core_offset += round_up(weight_range.total_bytes, 16)
 
             # Location of weights in tensor
             addr_range = NpuAddressRange(
@@ -526,13 +524,7 @@
                 if core == 0:
                     weight_range = cmd.in_tensor.encoded_ranges[key]
                     src_addr = cmd.in_tensor.address + weight_range.offset
-
-                    if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
-                        dest_addr = cmd.out_tensor.address + cmd.in_tensor.max_range_bytes * (
-                            get_double_buffer_offset(arch, weight_range.index, core)
-                        )
-                    else:
-                        dest_addr = cmd.out_tensor.address
+                    dest_addr = cmd.out_tensor.address
     else:
         start_coord = cmd.box.start_coord
         src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 19d0c11..ccf4929 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -63,7 +63,7 @@
     def mark_usage(self, op_time, op_length=1):
         op_time_start = max(op_time, 0)
         op_time_end = op_time + op_length
-        if op_time_end <= op_time_start:
+        if op_time_end < op_time_start:
             return
 
         self.start_time = min(self.start_time, op_time_start)
@@ -325,13 +325,20 @@
 
             rng.mark_usage(time_to_set)
 
-        weight_tens = op_info.buffered_weight_tensor
-        if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
-            rng = lr_graph.get_or_create_range(weight_tens)
-            if weight_tens.pre_buffer:
-                rng.mark_usage(time_to_set - 1, 2)
-            else:
-                rng.mark_usage(time_to_set)
+        for idx, weight_tens in enumerate(op_info.buffered_weight_tensors):
+            if weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
+                rng = lr_graph.get_or_create_range(weight_tens)
+                start_time = time_to_set
+                length = 1
+                if weight_tens.pre_buffer:
+                    start_time -= 1
+                    length += 1
+                if len(op_info.buffered_weight_tensors) > 1:
+                    last_idx = len(op_info.ofm_depth_slices) % len(op_info.buffered_weight_tensors)
+                    # Double buffering: reduce end time of the buffer that is not used last
+                    if last_idx != idx:
+                        length -= 1
+                rng.mark_usage(start_time, length)
 
         if time_to_set == lr_graph.current_time:
             lr_graph.current_time += 2
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 81d0be7..0c8a907 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -620,8 +620,8 @@
     prev_cost = schedule.cost_map[prev_op] if prev_op else None
     if op.parent_op.bias:
         query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth)
-        if cost.buffered_weight_tensor:
-            query.const_memory_area = cost.buffered_weight_tensor.mem_area
+        if cost.buffered_weight_tensors:
+            query.const_memory_area = cost.buffered_weight_tensors[0].mem_area
         else:
             query.const_memory_area = cost.npu_weights_tensor.mem_area
 
@@ -649,7 +649,7 @@
             # LUT read from SHRAM TODO remove?
             scaled_bws[lut_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw
 
-    if cost.npu_weights_tensor and cost.buffered_weight_tensor:
+    if cost.npu_weights_tensor and cost.buffered_weight_tensors:
         # DMA Weight Transfer
         sz = 0
         # Get the size of the first DMA
@@ -661,10 +661,10 @@
 
         total_sz = len(cost.npu_weights_tensor.buffer)
         bws[cost.npu_weights_tensor.mem_area][TensorPurpose.Weights][BandwidthDirection.Read] += total_sz
-        bws[cost.buffered_weight_tensor.mem_area][TensorPurpose.Weights][BandwidthDirection.Write] += total_sz
+        bws[cost.buffered_weight_tensors[0].mem_area][TensorPurpose.Weights][BandwidthDirection.Write] += total_sz
 
         ws_first_transfer_cycles = measure_mem2mem_cycles(
-            arch, cost.npu_weights_tensor.mem_area, cost.buffered_weight_tensor.mem_area, sz
+            arch, cost.npu_weights_tensor.mem_area, cost.buffered_weight_tensors[0].mem_area, sz
         )
 
         # Add cycles for Weight + Scale Transfer
@@ -720,7 +720,7 @@
         bw = access.const_read[0] * bandwidth_compression_scale_approx
         bws[query.const_memory_area][TensorPurpose.Weights][BandwidthDirection.Read] += bw
 
-        if not cost.buffered_weight_tensor:
+        if not cost.buffered_weight_tensors:
             scaled_bws[query.const_memory_area][TensorPurpose.Weights][BandwidthDirection.Read] += bw
 
     if access.const_read[1] > 0:
@@ -728,7 +728,7 @@
         bw = access.const_read[1] * op.parent_op.bias.element_size()
         bws[query.const_memory_area][TensorPurpose.FSBias][BandwidthDirection.Read] += bw
 
-        if not cost.buffered_weight_tensor:
+        if not cost.buffered_weight_tensors:
             scaled_bws[query.const_memory_area][TensorPurpose.FSBias][BandwidthDirection.Read] += bw
 
     update_summary_cycles(arch, scaled_bws, cycles_a)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 3cfde28..67b890e 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -107,7 +107,7 @@
         self.ofm_depth_slices: List[int] = [0, stripe.depth]
         self.npu_weights_tensor: Optional[NpuWeightTensor] = None
         self.npu_scales_tensor: Optional[NpuWeightTensor] = None
-        self.buffered_weight_tensor: Optional[Tensor] = None
+        self.buffered_weight_tensors: List[Tensor] = []
         self.cycles: Optional[CycleCost] = None
         self.slack_buffering_cycles = 0
         self.slack_buffering_memory = 0
@@ -131,9 +131,8 @@
         res += f"\t\tIFM2 Stripe  = {self.stripe_input2}\n"
         res += f"\t\tOFM Stripe   = {self.stripe}\n"
         res += f"\t\tEncoded Weights = {self.npu_weights_tensor and len(self.npu_weights_tensor.buffer)} bytes\n"
-        res += (
-            f"\t\tWeight buffer = {self.buffered_weight_tensor and self.buffered_weight_tensor.storage_size()} bytes\n"
-        )
+        for idx, tens in enumerate(self.buffered_weight_tensors):
+            res += f"\t\tWeight buffer{idx + 1} = {tens.storage_size()} bytes\n"
         res += f"\t\tDepth slices = {self.ofm_depth_slices}\n"
         res += f"\t\tAssigned Cascade = {self.cascade}"
         return res
@@ -734,7 +733,7 @@
                     # Chosen buffering might not fit at all, iterate until it does
                     # or until the minimum usable slice size is reached
                     if (
-                        encoded_weights.max_range_bytes <= half_buffer_limit
+                        encoded_weights.double_buffer_size() <= buffer_limit_bytes
                         or prebuffer_depth == ArchitectureFeatures.OFMSplitDepth
                     ):
                         break
@@ -751,24 +750,42 @@
                 cost.slack_buffering_cycles = tail_cycles.op_cycles
 
         # Determine whether the weights need to be double buffered
-        weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes)
+        weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes())
 
         # Only buffer weights if there's still space left for the buffer
         if weight_buffer_size <= buffer_limit_bytes:
             assert weight_buffer_size % 16 == 0
             # Determine whether to double buffer or single buffer
-            if (weight_buffer_size * 2 <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
-                weight_buffer_size = weight_buffer_size * 2
+            double_buffer_size = encoded_weights.double_buffer_size()
+            if (double_buffer_size <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)):
                 weight_tensor_purpose = TensorSubPurpose.DoubleBuffer
             else:
                 weight_tensor_purpose = TensorSubPurpose.Standard
 
-            cost.buffered_weight_tensor = self.buffer_tensor(
-                encoded_weights, weight_tensor_purpose, weight_buffer_size, weight_tensor.name
-            )
+            cost.buffered_weight_tensors = [
+                self.buffer_tensor(
+                    encoded_weights,
+                    weight_tensor_purpose,
+                    encoded_weights.double_buffer_sizes[0],
+                    weight_tensor.name + "_buffer",
+                )
+            ]
+            if weight_tensor_purpose == TensorSubPurpose.DoubleBuffer:
+                buf2 = self.buffer_tensor(
+                    encoded_weights,
+                    weight_tensor_purpose,
+                    encoded_weights.double_buffer_sizes[1],
+                    weight_tensor.name + "_buffer2",
+                )
+                cost.buffered_weight_tensors.append(buf2)
+
+            last_used_buffer_idx = len(cost.ofm_depth_slices) % len(cost.buffered_weight_tensors)
+            weight_buffer_size = encoded_weights.double_buffer_sizes[last_used_buffer_idx]
+
             if ref_cost.cascade == 0:
-                # Determine if the lifetime can be extended and pre-buffer weights under the previous operation
-                cost.buffered_weight_tensor.pre_buffer = weight_buffer_size < slack_memory
+                # Determine if the lifetime can be extended and pre-buffer the first weight buffer
+                # under the previous operation
+                cost.buffered_weight_tensors[0].pre_buffer = encoded_weights.double_buffer_size() < slack_memory
 
             cost.slack_buffering_memory -= weight_buffer_size
         else:
@@ -781,7 +798,7 @@
         cost.npu_scales_tensor = encoded_scales
 
     def buffer_tensor(self, src_tensor: Tensor, sub_purpose: TensorSubPurpose, buffer_size: int, name: str) -> Tensor:
-        buffered_weight_tensor = Tensor([1, 1, 1, buffer_size], DataType.uint8, name + "_buffer")
+        buffered_weight_tensor = Tensor([1, 1, 1, buffer_size], DataType.uint8, name)
         buffered_weight_tensor.src_tensor = src_tensor
         buffered_weight_tensor.mem_area = self.arch.fast_storage_mem_area
         buffered_weight_tensor.mem_type = MemType.Scratch_fast
@@ -823,11 +840,16 @@
             # Create a cost entry with the new stripe
             cost = sched_op.create_scheduler_info(self.nng, stripe)
 
-            if ref_cost[sched_op].buffered_weight_tensor:
+            weight_tensor = cost.npu_weights_tensor
+            for idx, buffered_tens in enumerate(ref_cost[sched_op].buffered_weight_tensors):
                 # If the weights are buffered in the reference schedule they should be in the new proposal
-                weight_tensor = cost.npu_weights_tensor
-                cost.buffered_weight_tensor = self.buffer_tensor(
-                    weight_tensor, TensorSubPurpose.Standard, len(weight_tensor.buffer), weight_tensor.name
+                cost.buffered_weight_tensors.append(
+                    self.buffer_tensor(
+                        weight_tensor,
+                        buffered_tens.sub_purpose,
+                        weight_tensor.double_buffer_sizes[idx],
+                        buffered_tens.name,
+                    )
                 )
 
             # Estimate performance
@@ -856,9 +878,7 @@
                 peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage)
             else:
                 # This Op is not part of a cascade - calculate the memory usage
-                op_weight_buffer = 0
-                if cost[sched_op].buffered_weight_tensor:
-                    op_weight_buffer = cost[sched_op].buffered_weight_tensor.storage_size()
+                op_weight_buffer = sum(tens.storage_size() for tens in cost[sched_op].buffered_weight_tensors)
 
                 op_mem_usage = (
                     sched_op.ifm_size_in_bytes()
@@ -1013,8 +1033,8 @@
         weight_ops = {}
         for sched_op in self.sched_ops:
             cost = self.sg.schedule.cost_map[sched_op]
-            if cost.buffered_weight_tensor:
-                weight_ops[cost.buffered_weight_tensor] = sched_op
+            for tens in cost.buffered_weight_tensors:
+                weight_ops[tens] = sched_op
 
         # Filter out weight buffer live ranges
         weight_lrs = []
@@ -1088,8 +1108,8 @@
             sched_op.parent_ps.block_config = op_info.block_config.old_style_representation()
 
             # Ensure that the src_tensor reference is set correctly
-            if op_info.buffered_weight_tensor:
-                op_info.buffered_weight_tensor.src_tensor = op_info.npu_weights_tensor
+            for tens in op_info.buffered_weight_tensors:
+                tens.src_tensor = op_info.npu_weights_tensor
 
     def use_fast_storage_for_feature_maps(self, schedule, staging_limit):
         max_mem_usage = []
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 86b424a..78c4351 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -68,12 +68,19 @@
     def __init__(self, name):
         Tensor.__init__(self, None, None, name + "_npu_encoded_weights")
         self.buffer = []
-        self.max_range_bytes = 0
+        self.double_buffer_sizes = [0, 0]  # Required sizes if double buffering is used
         self.encoded_ranges = OrderedDict()
         self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
         self.dtype = DataType.uint8
         self.scale_compression_config = None
 
+    def max_range_bytes(self):
+        return max(self.double_buffer_sizes)
+
+    def double_buffer_size(self):
+        """Return total required size for double buffering"""
+        return sum(self.double_buffer_sizes)
+
 
 class CompressedWeightCache:
     """Global tensor weight compression cache"""
@@ -357,7 +364,7 @@
             weights = np.flip(weights, axis=(0, 1))
 
     encoded_stream = bytearray()
-    max_single_buffer_len = 0
+    double_buffer_sizes = [0, 0]
     is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise
 
     # Bias & scale
@@ -435,11 +442,11 @@
                 npu_tensor.encoded_ranges[key] = weight_range
 
         # Remember maximum encoded length for DoubleBuffering
-        max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream) - buffer_start_offset)
+        double_buffer_sizes[idx % 2] = max(double_buffer_sizes[idx % 2], len(encoded_stream) - buffer_start_offset)
 
     # Attach buffer to tensor
     npu_tensor.buffer = encoded_stream
-    npu_tensor.max_range_bytes = max_single_buffer_len
+    npu_tensor.double_buffer_sizes = double_buffer_sizes
     npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
     npu_tensor.format = TensorFormat.WeightsCompressed