MLBEDSW-4034: New Scheduler Size or Performance Optimisation

 - Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 9a1d5a1..652d016 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -16,6 +16,7 @@
 # Description:
 # Compresses and pads the weigths. It also calculates the scales and packs with the biases.
 from collections import namedtuple
+from collections import OrderedDict
 from typing import Tuple
 
 import numpy as np
@@ -25,27 +26,80 @@
 from .architecture_features import ArchitectureFeatures
 from .data_type import DataType
 from .errors import UnsupportedFeatureError
-from .nn_graph import SchedulingStrategy
 from .numeric_util import round_up
-from .numeric_util import round_up_divide
 from .operation import NpuBlockType
 from .operation import Op
 from .scaling import quantise_scale
 from .scaling import reduced_quantise_scale
-from .tensor import create_equivalence_id
-from .tensor import TensorBlockTraversal
+from .tensor import Tensor
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
-from .tensor import TensorSubPurpose
 from ethosu import mlw_codec
 
 
 # Contains meta info for a weight compression. If two tensors have identical weight compression config,
 # then they also will have identical compressed weights.
 WeightCompressionConfig = namedtuple(
-    "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "value_id"]
+    "WeightCompressionConfig",
+    ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id", "scale_value_id"],
 )
 
+WeightKey = namedtuple("WeightKey", ["core", "depth"])
+
+
+class WeightRange:
+    def __init__(self):
+        self.offset = 0
+        self.scale_bytes = 0
+        self.weight_offset = 0
+        self.weight_bytes = 0
+        self.index = 0
+
+    @property
+    def total_bytes(self):
+        return self.scale_bytes + self.weight_bytes
+
+
+class NpuWeightTensor(Tensor):
+    def __init__(self, name):
+        Tensor.__init__(self, None, None, name + "_npu_encoded_weights")
+        self.buffer = []
+        self.max_range_bytes = 0
+        self.encoded_ranges = OrderedDict()
+        self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
+        self.dtype = DataType.uint8
+
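For orientation, the ranges recorded per (core, depth) key can be consumed roughly as sketched below once the buffer has been placed in memory; the dump_weight_ranges helper and base_address parameter are illustrative assumptions, not part of this change.

    # Hypothetical consumer of NpuWeightTensor.encoded_ranges (illustration only)
    def dump_weight_ranges(npu_tensor, base_address=0):
        for key, wr in npu_tensor.encoded_ranges.items():
            # key is a WeightKey(core, depth); wr is a WeightRange within npu_tensor.buffer
            scale_addr = base_address + wr.offset                       # packed scales/biases start here
            weight_addr = base_address + wr.offset + wr.weight_offset   # encoded weights follow the 16-byte aligned scales
            print(key.core, key.depth, wr.scale_bytes, wr.weight_bytes, scale_addr, weight_addr)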
+
+class CompressedWeightCache:
+    """Global tensor weight compression cache"""
+
+    cache = {}
+
+    @staticmethod
+    def get_tensor_with_same_compression(wcc):
+        return CompressedWeightCache.cache.get(wcc)
+
+    @staticmethod
+    def add(tens):
+        # Adds the compressed weights from the tensor to the cache
+        wcc = tens.weight_compression_config
+        CompressedWeightCache.cache[wcc] = tens
+
+    @staticmethod
+    def has_tensor_with_same_compression(wcc):
+        return wcc in CompressedWeightCache.cache
+
+
+def create_weight_compression_config(
+    weight_tens, scale_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation
+):
+    # Note: for an ofm block only its depth is used in weight compression.
+    # A block depth > ofm depth gives the same result as block depth == ofm depth
+    block_depth = min(ofm_block_depth, weight_tens.quant_values.shape[-1])
+    return WeightCompressionConfig(
+        npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id, scale_tens.value_id
+    )
+
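Taken together, the cache and the config constructor are used by encode_weight_and_scale_tensor() further down in this patch roughly as follows (simplified outline; argument handling abbreviated):

    # Simplified outline of the cache check performed in encode_weight_and_scale_tensor()
    wcc = create_weight_compression_config(
        weight_tens, scale_tens, npu_block_type, ofm_block_depth, hash(str(depth_offsets)), dilation
    )
    cached = CompressedWeightCache.get_tensor_with_same_compression(wcc)
    if cached is not None:
        return cached                      # identical config -> reuse the already encoded NpuWeightTensor
    npu_tensor = NpuWeightTensor(weight_tens.name)
    npu_tensor.weight_compression_config = wcc
    # ... encode scales, biases and weights into npu_tensor.buffer ...
    CompressedWeightCache.add(npu_tensor)  # later calls with an equal wcc hit the cache
    return npu_tensor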
 
 def encode_weights(
     accelerator: Accelerator,
@@ -140,185 +194,13 @@
     return data
 
 
-def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
-    # Note: for an ofm block only its depth is used in weight compression.
-    # And block depth > ofm depth gives same result as block depth == ofm depth
-    block_depth = min(ofm_block_depth, tens.quant_values.shape[-1])
-    return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, dilation, tens.value_id)
-
-
-def set_storage_shape(tens):
-    # Sets the storage shape depending on the tensor's sub purpose
-    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(tens.compressed_values) > 2:
-        offset = 2 * np.amax([len(x) for x in tens.compressed_values])
-        assert offset % 16 == 0
-    else:
-        offset = tens.weight_compressed_offsets[-1]
-    tens.storage_shape = [1, 1, 1, offset]
-
-
-class CompressedWeightCache:
-    # Contains weight compressions for all weight tensors in a graph
-    def __init__(self):
-        self.cache = {}  # maps from WeightCompressionConfig to a tensor clone containing compressed weights
-
-    def has_tensor_with_same_compression(self, wcc):
-        return self.cache.get(wcc) is not None
-
-    def get_tensor_with_same_compression(self, wcc):
-        cache_obj = self.cache.get(wcc)
-        return cache_obj[0] if cache_obj else None
-
-    def get_unencoded_size_with_same_compression(self, wcc):
-        cache_obj = self.cache.get(wcc)
-        return cache_obj[1] if cache_obj else None
-
-    def add(self, tens, unencoded_size):
-        # Adds the compressed weights from the tensor to the cache
-        wcc = tens.weight_compression_config
-        # Clone the tensor to make sure that nothing related to the weight compression is modified
-        tens_clone = tens.clone("_weights{}_{}".format(wcc.ofm_block_depth, wcc.ofm_depth_step))
-        self.cache[wcc] = (tens_clone, unencoded_size)
-
-
 def core_deinterleave(hwio, core, ncores):
     # Put weights back into OHWI
     ohwi = np.transpose(hwio, (3, 0, 1, 2))
     return ohwi[core : ohwi.shape[0] : ncores]
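As a quick illustration of core_deinterleave() (the shapes here are hypothetical): with two cores, output channels are dealt out round-robin, so each core encodes every second channel.

    # Illustration only: HWIO weights of shape (3, 3, 16, 8) on a 2-core configuration
    hwio = np.zeros((3, 3, 16, 8))
    core0 = core_deinterleave(hwio, core=0, ncores=2)  # OHWI, output channels 0, 2, 4, 6 -> shape (4, 3, 3, 16)
    core1 = core_deinterleave(hwio, core=1, ncores=2)  # OHWI, output channels 1, 3, 5, 7 -> shape (4, 3, 3, 16)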
 
 
-# Compress the weights
-def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
-    assert tens.purpose == TensorPurpose.Weights
-
-    # Check the weight cache
-    if nng.weight_cache is None:
-        nng.weight_cache = CompressedWeightCache()
-    wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation)
-    tens.weight_compression_config = wcc
-    # Reassign equivalence id such that tensors with same weight compression get identical equivalence ids,
-    # but tensors with the same values but different compression get different equivalence ids
-    tens.equivalence_id = create_equivalence_id(wcc)
-    tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
-    if tens_cached is not None:
-        # Cache hit, copy weights from the cache
-        tens.copy_compressed_weight_info(tens_cached)
-        set_storage_shape(tens)
-        return nng.weight_cache.get_unencoded_size_with_same_compression(wcc)
-    # No cache hit, perform the compression
-    assert tens.quantization is not None
-    assert tens.quantization.scale_f32 is not None
-    assert tens.quantization.zero_point is not None
-
-    zero_point = tens.quantization.zero_point
-    quant_buf = tens.quant_values.astype(np.int64)
-
-    # Early zero-point correction
-    weights = quant_buf - zero_point
-
-    if len(weights.shape) == 2:
-        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
-
-    compression_scales = []
-    compressed_offsets = []
-    encoded_streams = []
-    encoded_streams_substream_offsets = []
-    offset = 0
-    max_single_buffer_len = 0
-    unencoded_size = 0
-
-    ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
-    ifm_depth = weights.shape[-2]
-    if npu_block_type == NpuBlockType.ConvolutionDepthWise:
-        tens.block_traversal = TensorBlockTraversal.DepthWise
-    if npu_block_type == NpuBlockType.ConvolutionMxN:
-        # Determine which block traversal strategy has better DPU utilization
-        kernel_size = weights.shape[0] * weights.shape[1]
-        depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
-        part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
-            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
-        )
-        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
-            # Part-kernel first is always better for ifm depths <= 8
-            tens.block_traversal = TensorBlockTraversal.PartKernelFirst
-        else:
-            tens.block_traversal = TensorBlockTraversal.DepthFirst
-
-    is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
-    if tens.block_traversal == TensorBlockTraversal.PartKernelFirst:
-        block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
-    else:
-        block_traversal = NpuBlockTraversal.DEPTH_FIRST
-
-    if tens.consumer_list[0].type == Op.Conv2DBackpropInputSwitchedBias:
-        # Transpose Convoluion, reverse weights in H and W axes
-        weights = np.flip(weights, axis=(0, 1))
-
-    # Calculate brick size
-    brick_size = (weights.shape[0], weights.shape[1], weights.shape[2], min(tens.shape[-1], ofm_depth_step))
-    elements_in_brick = np.prod(brick_size)
-
-    # Slice weight stream up depth-ways into bricks and compress
-    full_ofm_depth = quant_buf.shape[-1]
-    for idx in range(0, full_ofm_depth, ofm_depth_step):
-        # Get the weights necessary for this brick
-        count = min(full_ofm_depth - idx, ofm_depth_step)
-        brick_weights = weights[:, :, :, idx : idx + count]
-
-        substream_offsets = [0]
-        encoded_stream = []
-
-        # For each core, deinterleave weights from the larger volume
-        # and generate separate compressed streams.
-        for core in range(0, min(arch.ncores, full_ofm_depth)):
-            core_weights = core_deinterleave(brick_weights, core, arch.ncores)
-
-            block_depth = (ofm_block_depth + arch.ncores - 1 - core) // arch.ncores
-            encoded_substream = []
-            if block_depth != 0:
-                encoded_substream, raw_stream_size = encode_weights(
-                    accelerator=arch.accelerator_config,
-                    weights_volume=core_weights,
-                    dilation_xy=dilation,
-                    ifm_bitdepth=ifm_bitdepth,
-                    ofm_block_depth=block_depth,
-                    is_depthwise=is_depthwise,
-                    block_traversal=block_traversal,
-                )
-                unencoded_size += raw_stream_size
-            encoded_stream.extend(encoded_substream)
-            substream_offsets.append(len(encoded_stream))
-
-        encoded_streams.append(encoded_stream)
-        encoded_streams_substream_offsets.append(substream_offsets)
-
-        # Remember maximum encoded length for DoubleBuffering
-        max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream))
-
-        # Remember where we put it for linear addressing
-        compressed_offsets.append(offset)
-        offset += len(encoded_stream)
-        assert offset % 16 == 0
-
-        # Compression scale tracking
-        compression_scales.append(len(encoded_stream) / elements_in_brick)
-
-    # Track total length as last element of the offsets array
-    compressed_offsets.append(offset)
-
-    tens.weight_compression_scales = compression_scales
-    tens.weight_compressed_offsets = compressed_offsets
-    tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
-    tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
-    tens.compressed_values = encoded_streams
-    tens.compressed_values_substream_offsets = encoded_streams_substream_offsets
-    tens.brick_size = brick_size
-    set_storage_shape(tens)
-    nng.weight_cache.add(tens, unencoded_size)
-    return unencoded_size
-
-
-def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=False):
+def _prepare_scale_and_bias(arch, tens, rescale_for_faf):
     assert tens.purpose in [TensorPurpose.FeatureMap, TensorPurpose.FSBias]
     assert tens.format == TensorFormat.NHWC
     # the connected operator should expect a bias input unless it is a FullyConnected
@@ -381,79 +263,157 @@
     else:
         quantised_scales = [quantise_scale(scale) for scale in scales]
 
-    # pack the biases and scales
+    # If only 1 quantised scale is used, repeat that value for the length of the biases
     if len(quantised_scales) == 1:
-        # If only 1 quantised scale is used, repeat that value for the length of the biases
         quantised_scales = [quantised_scales[0]] * len(biases)
 
-    assert len(quantised_scales) == len(biases)
-    tens.element_size_bytes = 10
-    tens.compressed_values = []
-    tens.compressed_values_substream_offsets = []
-
-    total_elements = len(quantised_scales)
-    alignment_bytes = 0
-    for i in range(0, total_elements, ofm_depth_step):
-        # Extract streams from brick to generate substreams for each core
-        stream = bytearray()
-        substream_offsets = [0]
-        max_len = min(ofm_depth_step, total_elements - i)
-        for core in range(0, min(arch.ncores, max_len)):
-            core_scales = quantised_scales[i + core : i + core + max_len : arch.ncores]
-            core_biases = biases[i + core : i + core + max_len : arch.ncores]
-            for j, core_bias in enumerate(core_biases):
-                stream.extend(encode_bias(np.int64(core_bias), *core_scales[j]))
-
-            # Align to 16 for start for next substream
-            remainder = (len(stream)) % 16
-            if remainder > 0:
-                stream.extend(bytearray(16 - remainder))
-                alignment_bytes += 16 - remainder
-
-            substream_offsets.append(len(stream))
-
-        # Add to compressed values with their substream offset lists to the tensor
-        tens.compressed_values.append(stream)
-        tens.compressed_values_substream_offsets.append(substream_offsets)
-
-    tens.storage_shape = [total_elements + round_up_divide(alignment_bytes, tens.element_size_bytes)]
+    return quantised_scales, biases
 
 
-def update_pass_weight_and_scale_tensors(nng, arch):
-    for sg in nng.subgraphs:
-        for ps in sg.passes:
-            tens = ps.weight_tensor
-            if tens is not None:
-                op = tens.find_npu_op()
-                if op is None:
-                    continue
-                needs_dma = tens.needs_dma()
-                if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
-                    ofm_depth_step = ps.block_config[-1]
-                else:
-                    ofm_depth_step = tens.shape[-1]
-                nng.total_npu_weights += compress_weights(
-                    arch, nng, tens, op.type.npu_block_type, ps.block_config[-1], ofm_depth_step, op.get_dilation_h_w()
+def encode_weight_and_scale_tensor(
+    arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets, rescale_for_faf=False
+) -> NpuWeightTensor:
+    npu_block_type = op.type.npu_block_type
+
+    wcc = create_weight_compression_config(
+        weight_tens, scale_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
+    )
+
+    tens_cached = CompressedWeightCache.get_tensor_with_same_compression(wcc)
+    if tens_cached is not None:
+        return tens_cached
+
+    npu_tensor = NpuWeightTensor(weight_tens.name)
+    npu_tensor.weight_compression_config = wcc
+
+    # No cache hit, perform the compression
+    assert weight_tens.quantization is not None
+    assert weight_tens.quantization.scale_f32 is not None
+    assert weight_tens.quantization.zero_point is not None
+
+    zero_point = weight_tens.quantization.zero_point
+    quant_buf = weight_tens.quant_values.astype(np.int64)
+
+    # Early zero-point correction
+    weights = quant_buf - zero_point
+
+    if len(weights.shape) == 2:
+        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
+
+    # Expect this (undilated) equivalence
+    assert kernel.height == weights.shape[0]
+    assert kernel.width == weights.shape[1]
+    # Ensure depth offsets are terminated at end of OFM shape
+    assert len(depth_offsets) > 1, "Require closed depth ranges"
+
+    ifm_bitdepth = op.inputs[0].dtype.size_in_bits()
+    ifm_depth = weights.shape[-2]
+
+    # Default HW traversal
+    npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
+
+    if npu_block_type == NpuBlockType.ConvolutionMxN:
+        # Determine which block traversal strategy has better DPU utilization
+        kernel_size = weights.shape[0] * weights.shape[1]
+        depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
+        part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
+            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
+        )
+        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
+            # Part-kernel first is always better for ifm depths <= 8
+            npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+
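As a worked example of this heuristic (hypothetical shapes, not taken from this patch): an 8-bit IFM of depth 16 with a 3x3 kernel gives

    # Worked example (hypothetical): ifm depth 16, 3x3 kernel, 8-bit ifm
    depth_utilization       = 16 / round_up(16, 32)                          # 16/32 = 0.50
    part_kernel_utilization = (16 / round_up(16, 8)) * (9 / round_up(9, 4))  # 1.0 * 9/12 = 0.75
    # 0.75 >= 0.50 -> NpuBlockTraversal.PART_KERNEL_FIRST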
+    if op.type == Op.Conv2DBackpropInputSwitchedBias:
+        # Transpose Convolution: reverse weights in H and W axes
+        weights = np.flip(weights, axis=(0, 1))
+
+    encoded_stream = bytearray()
+    max_single_buffer_len = 0
+    is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise
+
+    # Bias & scale
+    if scale_tens:
+        quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf)
+        scale_tens.element_size_bytes = 10
+
+    # Slice the weight stream up depth-ways into bricks and compress
+    full_ofm_depth = quant_buf.shape[-1]
+    ofm_block_depth = block_config.ofm_block.depth
+
+    weight_range_index = 0
+    for idx, depth_offset in enumerate(depth_offsets[:-1]):
+        # Do not generate for offsets outside the OFM
+        assert depth_offset >= 0 and depth_offset < full_ofm_depth
+        depth_length = depth_offsets[idx + 1] - depth_offset
+
+        # Get the weights necessary for this brick
+        brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]
+
+        buffer_start_offset = len(encoded_stream)
+
+        # For each core, deinterleave weights from the larger volume
+        # and generate separate compressed streams.
+        for core in range(0, min(arch.ncores, full_ofm_depth)):
+
+            core_block_depth = int((ofm_block_depth + arch.ncores - 1 - core) // arch.ncores)
+
+            if core_block_depth != 0:
+                key = WeightKey(core, depth_offset)
+                weight_range = WeightRange()
+                weight_range.offset = len(encoded_stream)
+                weight_range.index = weight_range_index
+                weight_range_index += 1
+
+                # Scales & biases
+                if scale_tens:
+                    scale_stream = []
+                    core_scales = quantised_scales[
+                        depth_offset + core : depth_offset + core + depth_length : arch.ncores
+                    ]
+                    core_biases = biases[depth_offset + core : depth_offset + core + depth_length : arch.ncores]
+                    for j, core_bias in enumerate(core_biases):
+                        scale_stream.extend(encode_bias(np.int64(core_bias), *core_scales[j]))
+
+                    weight_range.scale_bytes = len(scale_stream)
+
+                    encoded_stream.extend(scale_stream)
+
+                    # Align to 16 for start of next substream
+                    remainder = len(encoded_stream) % 16
+                    if remainder > 0:
+                        encoded_stream.extend(bytearray(16 - remainder))
+
+                # Weights
+                core_weights = core_deinterleave(brick_weights, core, arch.ncores)
+                encoded_substream, _ = encode_weights(
+                    accelerator=arch.accelerator_config,
+                    weights_volume=core_weights,
+                    dilation_xy=kernel.dilation,
+                    ifm_bitdepth=ifm_bitdepth,
+                    ofm_block_depth=core_block_depth,
+                    is_depthwise=is_depthwise,
+                    block_traversal=npu_tensor.hw_traversal,
                 )
-                nng.total_npu_encoded_weights += tens.weight_compressed_offsets[-1]
-                nng.total_original_weights += int(tens.elements() * tens.element_size())
 
-                # Update source tensor
-                if needs_dma:
-                    src_tens = tens.get_dma_src_tensor()
-                    src_tens.shape = tens.shape
-                    src_tens.quant_values = tens.quant_values
-                    src_tens.copy_compressed_weight_info(tens)
-                    set_storage_shape(src_tens)
+                weight_range.weight_offset = len(encoded_stream) - weight_range.offset
+                weight_range.weight_bytes = len(encoded_substream)
 
-            if ps.scale_tensor is not None:
-                rescale_for_faf = False
-                if (ps.ops[-1].type in (Op.Sigmoid, Op.Tanh)) and (ps.npu_block_type != NpuBlockType.ElementWise):
-                    rescale_for_faf = True
-                calc_scales_and_pack_biases(ps.scale_tensor, arch, ofm_depth_step, rescale_for_faf)
-                if ps.scale_tensor.ops[0].type == Op.DMA:
-                    src_tens = ps.scale_tensor.get_dma_src_tensor()
-                    src_tens.shape = ps.scale_tensor.shape
-                    src_tens.quant_values = ps.scale_tensor.quant_values
-                    src_tens.element_size_bytes = ps.scale_tensor.element_size_bytes
-                    src_tens.copy_compressed_weight_info(ps.scale_tensor)
+                # Append encoded weights section
+                encoded_stream.extend(encoded_substream)
+                assert len(encoded_stream) % 16 == 0
+
+                # Record encoded range in weights tensor
+                npu_tensor.encoded_ranges[key] = weight_range
+
+        # Remember maximum encoded length for DoubleBuffering
+        max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream) - buffer_start_offset)
+
+    npu_tensor.buffer = encoded_stream
+    npu_tensor.max_range_bytes = max_single_buffer_len
+    npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
+    npu_tensor.format = TensorFormat.WeightsCompressed
+    npu_tensor.purpose = TensorPurpose.Weights
+    npu_tensor.mem_area = weight_tens.mem_area
+    npu_tensor.mem_type = weight_tens.mem_type
+    CompressedWeightCache.add(npu_tensor)
+    return npu_tensor
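A usage sketch for the new entry point, under stated assumptions: the caller (the scheduler) supplies a closed depth_offsets list that ends at the full OFM depth, e.g. [0, 48, 96, 128] for an OFM depth of 128 split in steps of 48; block_config.ofm_block.depth and op.kernel are assumed to be available on the caller's side.

    # Illustrative caller sketch; not part of this change
    ofm_depth = weight_tens.shape[-1]
    depth_step = block_config.ofm_block.depth                              # or a larger step chosen by the scheduler
    depth_offsets = list(range(0, ofm_depth, depth_step)) + [ofm_depth]    # closed ranges, as asserted above
    npu_weights = encode_weight_and_scale_tensor(
        arch, op, weight_tens, scale_tens, op.kernel, block_config, depth_offsets
    )
    # npu_weights.buffer holds the interleaved scale/weight streams;
    # npu_weights.encoded_ranges[WeightKey(core, depth_offset)] locates each core's substream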