MLBEDSW-1941: Bug fix for shared weights

If the same weight tensor was used with different block configs,
errors would occur.

Fixed by always cloning weight tensors, using a global weight
compression cache, and modifying the linear allocator to
detect multiple usages of the same weight compression.
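
The core idea, as a minimal sketch (CompressionKey, WeightCache and
compress_fn are illustrative placeholders, not the actual vela API):

    from collections import namedtuple

    # Two tensors compressed with the same config and the same source
    # weights (equivalence_id) yield identical compressed streams, so
    # the result can be computed once, cached and shared.
    CompressionKey = namedtuple(
        "CompressionKey",
        ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"],
    )

    class WeightCache:
        def __init__(self):
            self.cache = {}  # CompressionKey -> compressed weight stream

        def get_or_compress(self, key, compress_fn):
            if key not in self.cache:
                self.cache[key] = compress_fn()  # miss: compress once
            return self.cache[key]  # hit: reuse the earlier result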

Change-Id: I91ca59176e1c59c66e0ac7a4227f2b5f0b47053f
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index a81b1fb..450e091 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -21,7 +21,6 @@
 import numpy as np
 from ethosu import mlw_codec
 
-from .architecture_features import Block
 from .data_type import DataType
 from .errors import UnsupportedFeatureError
 from .nn_graph import SchedulingStrategy
@@ -35,6 +34,46 @@
 from .tensor import TensorSubPurpose
 
 
+# Contains meta info for a weight compression. If two tensors have identical weight compression configs,
+# then they will also have identical compressed weights.
+WeightCompressionConfig = namedtuple(
+    "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"]
+)
+
+
+def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step):
+    # Note: for an ofm block, only its depth is used in weight compression,
+    # and a block depth > ofm depth gives the same result as block depth == ofm depth
+    block_depth = min(ofm_block_depth, tens.quant_values.shape[-1])
+    return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, tens.equivalence_id)
+
+
+def set_storage_shape(tens):
+    # Sets the storage shape depending on the tensor's sub purpose
+    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(tens.compressed_values) > 2:
+        offset = 2 * np.amax([len(x) for x in tens.compressed_values])
+        assert offset % 16 == 0
+    else:
+        offset = tens.weight_compressed_offsets[-1]
+    tens.storage_shape = [1, 1, 1, offset]
+
+
+class CompressedWeightCache:
+    # Contains weight compressions for all weight tensors in a graph
+    def __init__(self):
+        self.cache = {}  # maps from WeightCompressionConfig to a tensor clone containing compressed weights
+
+    def get_tensor_with_same_compression(self, wcc):
+        return self.cache.get(wcc)
+
+    def add(self, tens):
+        # Adds the compressed weights from the tensor to the cache
+        wcc = tens.weight_compression_config
+        # Clone the tensor to make sure that nothing related to the weight compression is modified
+        tens_clone = tens.clone("_weights{}_{}".format(wcc.ofm_block_depth, wcc.ofm_depth_step))
+        self.cache[wcc] = tens_clone
+
+
 def encode(weight_stream):
     assert np.amin(weight_stream) >= -255
     assert np.amax(weight_stream) <= 255
@@ -51,7 +90,7 @@
     return compressed
 
 
-def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth):
+def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth):
     is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
     is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
     subkernel_max = arch.subkernel_max
@@ -74,8 +113,8 @@
     stream = []
 
     # Top level striping - OFM blocks in the entire brick's depth
-    for ofm_block_z in range(0, ofm_depth, ofm_block.depth):
-        clipped_ofm_block_depth = min(ofm_block.depth, ofm_depth - ofm_block_z)
+    for ofm_block_z in range(0, ofm_depth, ofm_block_depth):
+        clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z)
         # IFM blocks required for the brick
         for ifm_block_z in range(0, (1 if is_depthwise else ifm_depth), ifm_block_depth):
             if is_depthwise:
@@ -139,20 +178,23 @@
 
 
 # Compress the weights
-def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_val=None, max_val=None):
+def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step):
     assert tens.purpose == TensorPurpose.Weights
     assert tens.format == TensorFormat.WeightsCompressed
 
-    WeightCompressionConfig = namedtuple("WeightCompressionConfig", ["npu_block_type", "ofm_block", "ofm_depth_step"])
+    # Check the weight cache
+    if nng.weight_cache is None:
+        nng.weight_cache = CompressedWeightCache()
+    wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step)
+    tens.weight_compression_config = wcc
+    tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
+    if tens_cached is not None:
+        # Cache hit, copy weights from the cache
+        tens.copy_compressed_weight_info(tens_cached)
+        set_storage_shape(tens)
+        return
 
-    # check if weights have already been compressed
-    wcc = tens.weight_compression_config
-    if wcc is not None:
-        assert wcc.npu_block_type == npu_block_type, "Weights not used by the same operator type"
-
-        if wcc.ofm_block == ofm_block and wcc.ofm_depth_step == ofm_depth_step:
-            return
-
+    # No cache hit, perform the compression
     assert tens.quantization is not None
     assert tens.quantization.scale_f32 is not None
     assert tens.quantization.zero_point is not None
@@ -173,7 +215,6 @@
     compressed_offsets = []
     encoded_streams = []
     offset = 0
-    max_single_buffer_len = 0
 
     ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
     ifm_depth = weights.shape[-2]
@@ -200,14 +241,10 @@
         brick_weights = weights[:, :, :, idx : idx + count]
 
         # Encode all weights into one chunk
-        raw_stream = generate_brick(arch, brick_weights, ofm_block, tens.block_traversal, ifm_bitdepth)
+        raw_stream = generate_brick(arch, brick_weights, ofm_block_depth, tens.block_traversal, ifm_bitdepth)
         encoded = encode(raw_stream)
         encoded_streams.append(encoded)
 
-        # Remember maximum encoded length for DoubleBuffering
-        if max_single_buffer_len < len(encoded):
-            max_single_buffer_len = len(encoded)
-
         # Remember where we put it for linear addressing
         compressed_offsets.append(offset)
         offset += len(encoded)
@@ -219,18 +256,14 @@
     # Also track complete length in the offsets array
     compressed_offsets.append(offset)
 
-    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(encoded_streams) > 2:
-        offset = 2 * max_single_buffer_len
-        assert offset % 16 == 0
-
-    tens.storage_shape = [1, 1, 1, offset]
     tens.weight_compression_scales = compression_scales
-    tens.weight_compression_config = WeightCompressionConfig(npu_block_type, ofm_block, ofm_depth_step)
     tens.weight_compressed_offsets = compressed_offsets
     tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
     tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
     tens.compressed_values = encoded_streams
     tens.brick_size = (weights_shape[0], weights_shape[1], weights_shape[2], min(tens.shape[-1], ofm_depth_step))
+    set_storage_shape(tens)
+    nng.weight_cache.add(tens)
 
 
 def calc_scales_and_pack_biases(tens, arch, oc_quantum, rescale_for_faf=False):
@@ -352,39 +385,29 @@
 
     for sg in nng.subgraphs:
         for ps in sg.passes:
-            if ps.weight_tensor is not None:
-                npu_usage_of_tensor = find_npu_usage_of_tensor(ps.weight_tensor)
+            tens = ps.weight_tensor
+            if tens is not None:
+                npu_usage_of_tensor = find_npu_usage_of_tensor(tens)
                 if npu_usage_of_tensor == NpuBlockType.ConvolutionDepthWise:
-                    ps.weight_tensor.quant_values = np.transpose(ps.weight_tensor.quant_values, (0, 1, 3, 2))
-                    ps.weight_tensor.shape = ps.weight_tensor.storage_shape = ps.weight_tensor.bandwidth_shape = list(
-                        ps.weight_tensor.quant_values.shape
-                    )
-                    ps.weight_tensor.weight_transpose_depthwise = True
+                    tens.quant_values = np.transpose(tens.quant_values, (0, 1, 3, 2))
+                    tens.shape = tens.storage_shape = tens.bandwidth_shape = list(tens.quant_values.shape)
+                    tens.weight_transpose_depthwise = True
 
-                needs_dma = len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA"
+                needs_dma = tens.needs_dma()
                 if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
                     ofm_depth_step = ps.block_config[-1]
                 else:
-                    ofm_depth_step = ps.weight_tensor.shape[-1]
-
+                    ofm_depth_step = tens.shape[-1]
                 compress_weights(
-                    ps.weight_tensor,
-                    arch,
-                    npu_usage_of_tensor,
-                    Block(ps.block_config[-3], ps.block_config[-4], ps.block_config[-1]),
-                    ofm_depth_step,
+                    arch, nng, tens, npu_usage_of_tensor, ps.block_config[-1], ofm_depth_step,
                 )
                 # Update source tensor
-                if len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA":
-                    src_tens = ps.weight_tensor.ops[0].inputs[0]
-                    src_tens.shape = ps.weight_tensor.shape
-                    src_tens.weight_transpose_depthwise = ps.weight_tensor.weight_transpose_depthwise
-                    src_tens.quant_values = ps.weight_tensor.quant_values
-                    src_tens.compressed_values = ps.weight_tensor.compressed_values
-                    src_tens.storage_shape = [1, 1, 1, ps.weight_tensor.weight_compressed_offsets[-1]]
-                    src_tens.brick_size = ps.weight_tensor.brick_size
-                    src_tens.weight_compression_scales = ps.weight_tensor.weight_compression_scales
-                    src_tens.weight_compressed_offsets = ps.weight_tensor.weight_compressed_offsets
+                if needs_dma:
+                    src_tens = tens.get_dma_src_tensor()
+                    src_tens.shape = tens.shape
+                    src_tens.quant_values = tens.quant_values
+                    src_tens.copy_compressed_weight_info(tens)
+                    set_storage_shape(src_tens)
 
             if ps.scale_tensor is not None:
                 rescale_for_faf = False