MLBEDSW-1941: Fix bug with shared weight tensors

Errors occurred if the same weight tensor was used with
different block configs.

Fixed by always cloning weight tensors, by introducing a
graph-wide weight compression cache, and by modifying the linear
allocator to detect multiple usages of the same weight
compression.

Change-Id: I91ca59176e1c59c66e0ac7a4227f2b5f0b47053f
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
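
Illustrative sketch for reviewers (not Vela code; the names
SimpleTensor, Config, compress and linear_allocate below are
hypothetical stand-ins for the real classes in this patch): two
weight tensors compressed under an identical config, including the
source tensor's equivalence id, hit the same cache entry, and the
linear allocator then maps them to one address so the duplicate
costs no extra storage.

    from collections import namedtuple

    # Two tensors with identical config (including the source tensor's
    # equivalence id) compress to identical streams.
    Config = namedtuple(
        "Config", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"]
    )

    class SimpleTensor:
        def __init__(self, name, config):
            self.name = name
            self.config = config      # stands in for tens.weight_compression_config
            self.compressed = None
            self.address = None
            self.size = 256           # dummy storage size in bytes

    compression_cache = {}            # Config -> compressed stream

    def compress(tens):
        # Reuse a previous compression when the config matches (cache hit)
        if tens.config in compression_cache:
            tens.compressed = compression_cache[tens.config]
            return
        tens.compressed = "encoded:%s" % (tens.config,)  # stand-in for mlw_codec
        compression_cache[tens.config] = tens.compressed

    def linear_allocate(tensors, granularity=16):
        # Assign increasing addresses, except duplicates share one address
        total, done = 0, []
        for tens in tensors:
            address = total
            for prev in done:
                if prev.config == tens.config:
                    address = prev.address
                    break
            tens.address = address
            done.append(tens)
            if address == total:      # only new data consumes space
                total += -(-tens.size // granularity) * granularity
        return total

    cfg = Config("ConvolutionMxN", 32, 32, equivalence_id=7)
    a, b = SimpleTensor("w1", cfg), SimpleTensor("w2", cfg)
    for t in (a, b):
        compress(t)
    linear_allocate([a, b])
    assert a.address == b.address     # duplicates land on the same address
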
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index 64aff06..b6a98a6 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -144,13 +144,14 @@
     # processed first during serialization into tensors
     first_npu_sg = nng.subgraphs[1]
     assert first_npu_sg.placement == PassPlacement.Npu
+    # Use the linear allocator for constant tensors
     tensor_allocation.allocate_tensors(
         nng,
         first_npu_sg,
         arch,
         permanent_storage,
         scheduler_options.use_ifm_ofm_overlap,
-        options.tensor_allocator,
+        TensorAllocator.LinearAlloc,
         options.verbose_allocation,
         options.show_minimum_possible_allocation,
         lr_graph_flash,
@@ -195,7 +196,7 @@
         arch,
         permanent_storage,
         scheduler_options.use_ifm_ofm_overlap,
-        options.tensor_allocator,
+        TensorAllocator.LinearAlloc,
         options.verbose_allocation,
         options.show_minimum_possible_allocation,
     )
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 0cc70a7..3b968dc 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -27,12 +27,8 @@
 from .tensor import TensorPurpose
 
 
-def need_dma(tens):
-    return len(tens.ops) == 1 and tens.ops[0].type == "DMA"
-
-
 def dma_if_necessary(ps, box, tensor):
-    if need_dma(tensor):
+    if tensor.needs_dma():
         dma_op = tensor.ops[0]
         in_tensor = dma_op.inputs[0]
         yield DMA(in_tensor, tensor, box)
@@ -93,7 +89,7 @@
     if strat == SchedulingStrategy.WeightStream:
         ofm_step = block_config[-1]
         ofm_stop = ofm_end[-1]
-        if weight_tensor is None or not need_dma(weight_tensor):
+        if weight_tensor is None or not weight_tensor.needs_dma():
             ofm_step = ofm_stop
         for start in range(ofm_start[-1], ofm_stop, ofm_step):
             end = min(start + ofm_step, ofm_stop)
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index e7b3e50..cd70446 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -17,7 +17,6 @@
 # Mark purpose and select formats for Tensors. Also compresses the weights.
 from . import rewrite_graph
 from . import weight_compressor
-from .architecture_features import Block
 from .operation import NpuBlockType
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
@@ -348,18 +347,12 @@
     for tens, fmt in formats_for_tensor.items():
         tens.set_format(fmt, arch)
         if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
-            npu_block_type = find_npu_usage_of_tensor(tens)
-            if len(tens.ops) == 1 and tens.ops[0].type == "DMA":
-                weight_compressor.compress_weights(tens, arch, npu_block_type, Block(32, 32, 32), 32)
+            src_tens = tens.get_dma_src_tensor()
+            if src_tens is not None:
+                npu_block_type = find_npu_usage_of_tensor(tens)
+                weight_compressor.compress_weights(arch, nng, tens, npu_block_type, 32, 32)
                 # Alias compressed weights back into source tensor
-                src_tens = tens.ops[0].inputs[0]
-                src_tens.compressed_values = tens.compressed_values
-                src_tens.storage_shape = tens.storage_shape
-                src_tens.brick_size = tens.brick_size
-                src_tens.weight_compression_scales = tens.weight_compression_scales
-                src_tens.weight_compressed_offsets = tens.weight_compressed_offsets
-                src_tens.compression_scale_for_worst_weight_stream = tens.compression_scale_for_worst_weight_stream
-                src_tens.storage_compression_scale = tens.storage_compression_scale
+                src_tens.copy_compressed_weight_info(tens)
 
     if verbose_tensor_format:
         nng.print_passes_with_tensors()
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index ed2ab32..ea35c08 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -485,6 +485,7 @@
         self.bits_per_element = {}
         self.total_size = {}
         self.total_elements = {}
+        self.weight_cache = None  # See CompressedWeightCache
 
     def get_root_subgraph(self):
         return self.subgraphs[0]
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 160cf63..2f91f61 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -225,7 +225,6 @@
         "quantization",
         "weight_compressed_offsets",
         "element_size_bytes",
-        "reshaped",
         "block_traversal",
         "offset",
         "cpu_tensor",
@@ -273,8 +272,6 @@
 
         # quantization parameters
         self.quantization = None
-
-        self.reshaped = False
         self.block_traversal = TensorBlockTraversal.Default
         self.resampling_mode = resampling_mode.NONE
 
@@ -294,20 +291,13 @@
 
         res.values = self.values
         res.quant_values = self.quant_values
-        res.compressed_values = self.compressed_values
         res.mem_area = self.mem_area
         res.format = self.format
         res.purpose = self.purpose
         res.sub_purpose = self.sub_purpose
         res.alignment = self.alignment
-        res.weight_transpose_depthwise = self.weight_transpose_depthwise
-
-        res.storage_compression_scale = self.storage_compression_scale
         res.bandwidth_compression_scale = self.bandwidth_compression_scale
-        res.compression_scale_for_worst_weight_stream = self.compression_scale_for_worst_weight_stream
-        res.weight_compression_scales = self.weight_compression_scales
         res.storage_rounding_quantum = self.storage_rounding_quantum
-        res.brick_size = self.brick_size
         res.address = 0
 
         if self.quantization is not None:
@@ -317,6 +307,7 @@
 
         res.resampling_mode = self.resampling_mode
 
+        res.copy_compressed_weight_info(self)
         return res
 
     def clone_into_fast_storage(self, arch):
@@ -324,6 +315,19 @@
         res.mem_area = arch.fast_storage_mem_area
         return res
 
+    def copy_compressed_weight_info(self, src_tens):
+        # Copies compressed values + all related weight compression info from the given tensor
+        self.compressed_values = src_tens.compressed_values
+        self.storage_shape = src_tens.storage_shape
+        self.brick_size = src_tens.brick_size
+        self.weight_compression_scales = src_tens.weight_compression_scales
+        self.weight_compressed_offsets = src_tens.weight_compressed_offsets
+        self.weight_transpose_depthwise = src_tens.weight_transpose_depthwise
+        self.compression_scale_for_worst_weight_stream = src_tens.compression_scale_for_worst_weight_stream
+        self.storage_compression_scale = src_tens.storage_compression_scale
+        self.block_traversal = src_tens.block_traversal
+        self.weight_compression_config = src_tens.weight_compression_config
+
     def set_format(self, fmt, arch):
         self.format = fmt
         shape_len = 0
@@ -527,6 +531,14 @@
 
         return strides
 
+    def needs_dma(self):
+        return len(self.ops) == 1 and self.ops[0].type == "DMA"
+
+    def get_dma_src_tensor(self):
+        # For weight tensors that need DMA: returns the source tensor in Flash, else None
+        # Note: for DMA ops, Pass.weight_tensor is referring to the SRAM weight tensor
+        return self.ops[0].inputs[0] if self.needs_dma() else None
+
     def compressed_stream_index_from_coord(self, coord):
         assert self.format == TensorFormat.WeightsCompressed
         assert len(self.compressed_values) > 0
@@ -575,7 +587,7 @@
             if len(self.weight_compressed_offsets) == 0:
                 return 0
 
-            if len(self.ops) == 1 and self.ops[0].type == "DMA" and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
+            if self.needs_dma() and self.sub_purpose == TensorSubPurpose.DoubleBuffer:
                 depth = orig_coord[-1]
                 brick_depth = self.brick_size[-1]
                 # Clamp position at final element index
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index cd2b570..e3952df 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -27,18 +27,26 @@
 from .tensor import MemArea
 
 
-def linear_allocate_live_ranges(live_ranges, alloc_granularity=256):
+def linear_allocate_live_ranges(live_ranges, alloc_granularity=16):
+    # Allocates using increasing addresses. Duplicate constant tensors will be allocated to the same address
     total_sz = 0
     allocated_tensors = []
 
-    # just assign increasing addresses
+    # just assign increasing addresses, except for duplicates
     for tens, lr in live_ranges.ranges.items():
         if tens in allocated_tensors:
             continue
 
-        lr.set_address(total_sz)
+        address = total_sz
+        if tens.weight_compression_config is not None:
+            for allocated_tens in allocated_tensors:
+                if allocated_tens.weight_compression_config == tens.weight_compression_config:
+                    address = allocated_tens.address
+                    break
+        lr.set_address(address)
         allocated_tensors += lr.tensors
-        total_sz += numeric_util.round_up(int(math.ceil(lr.size)), alloc_granularity)
+        if address == total_sz:
+            total_sz += numeric_util.round_up(int(math.ceil(lr.size)), alloc_granularity)
 
     return total_sz
 
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 109ae0e..5ab90f0 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -39,24 +39,24 @@
     return s.decode("utf-8")
 
 
-def reshape_tensor_add_const_op(tens, reorder):
-    if not tens.reshaped:
-        original_shape = tens.shape
-        tens.name = tens.name + "_reshape"
-        tens.shape = [original_shape[idx] for idx in reorder]
-        tens.bandwidth_shape = tens.shape
-        tens.storage_shape = tens.shape
+def clone_and_reshape_tensor(src_tens, reorder):
 
-        if tens.values is not None:
-            tens.values = tens.values.transpose(reorder)
+    tens = src_tens.clone("_reshape")
+    tens.shape = [src_tens.shape[idx] for idx in reorder]
+    tens.bandwidth_shape = tens.shape
+    tens.storage_shape = tens.shape
 
-        if tens.quant_values is not None:
-            tens.quant_values = tens.quant_values.transpose(reorder)
+    if tens.values is not None:
+        tens.values = tens.values.transpose(reorder)
 
-        op = Operation("Const", tens.name)
-        op.outputs = [tens]
-        tens.ops = [op]
-        tens.reshaped = True
+    if tens.quant_values is not None:
+        tens.quant_values = tens.quant_values.transpose(reorder)
+
+    op = Operation("Const", tens.name)
+    op.outputs = [tens]
+    tens.ops = [op]
+
+    return tens
 
 
 class TFLiteSubgraph:
@@ -137,10 +137,10 @@
         activation_function_to_split_out = None
 
         if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"):
-            reshape_tensor_add_const_op(inputs[1], (1, 2, 3, 0))
+            inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
 
         if op_type.startswith("FullyConnected"):
-            reshape_tensor_add_const_op(inputs[1], (1, 0))
+            inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
 
         if opt_serializer is not None:
             op.attrs = opt_serializer.deserialize(op_data.BuiltinOptions(), op_data.CustomOptionsAsNumpy())
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index a81b1fb..450e091 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -21,7 +21,6 @@
 import numpy as np
 from ethosu import mlw_codec
 
-from .architecture_features import Block
 from .data_type import DataType
 from .errors import UnsupportedFeatureError
 from .nn_graph import SchedulingStrategy
@@ -35,6 +34,46 @@
 from .tensor import TensorSubPurpose
 
 
+# Contains meta info for a weight compression. If two tensors have identical weight compression
+# configs, then they will also have identical compressed weights.
+WeightCompressionConfig = namedtuple(
+    "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "equivalence_id"]
+)
+
+
+def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step):
+    # Note: for an ofm block only its depth is used in weight compression.
+    # And block depth > ofm depth gives the same result as block depth == ofm depth
+    block_depth = min(ofm_block_depth, tens.quant_values.shape[-1])
+    return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, tens.equivalence_id)
+
+
+def set_storage_shape(tens):
+    # Sets the storage shape depending on the tensor's sub purpose
+    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(tens.compressed_values) > 2:
+        offset = 2 * np.amax([len(x) for x in tens.compressed_values])
+        assert offset % 16 == 0
+    else:
+        offset = tens.weight_compressed_offsets[-1]
+    tens.storage_shape = [1, 1, 1, offset]
+
+
+class CompressedWeightCache:
+    # Contains weight compressions for all weight tensors in a graph
+    def __init__(self):
+        self.cache = {}  # maps from WeightCompressionConfig to a tensor clone containing compressed weights
+
+    def get_tensor_with_same_compression(self, wcc):
+        return self.cache.get(wcc)
+
+    def add(self, tens):
+        # Adds the compressed weights from the tensor to the cache
+        wcc = tens.weight_compression_config
+        # Clone the tensor to make sure that nothing related to the weight compression is modified
+        tens_clone = tens.clone("_weights{}_{}".format(wcc.ofm_block_depth, wcc.ofm_depth_step))
+        self.cache[wcc] = tens_clone
+
+
 def encode(weight_stream):
     assert np.amin(weight_stream) >= -255
     assert np.amax(weight_stream) <= 255
@@ -51,7 +90,7 @@
     return compressed
 
 
-def generate_brick(arch, brick_weights, ofm_block, block_traversal, ifm_bitdepth):
+def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth):
     is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
     is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
     subkernel_max = arch.subkernel_max
@@ -74,8 +113,8 @@
     stream = []
 
     # Top level striping - OFM blocks in the entire brick's depth
-    for ofm_block_z in range(0, ofm_depth, ofm_block.depth):
-        clipped_ofm_block_depth = min(ofm_block.depth, ofm_depth - ofm_block_z)
+    for ofm_block_z in range(0, ofm_depth, ofm_block_depth):
+        clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z)
         # IFM blocks required for the brick
         for ifm_block_z in range(0, (1 if is_depthwise else ifm_depth), ifm_block_depth):
             if is_depthwise:
@@ -139,20 +178,23 @@
 
 
 # Compress the weights
-def compress_weights(tens, arch, npu_block_type, ofm_block, ofm_depth_step, min_val=None, max_val=None):
+def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step):
     assert tens.purpose == TensorPurpose.Weights
     assert tens.format == TensorFormat.WeightsCompressed
 
-    WeightCompressionConfig = namedtuple("WeightCompressionConfig", ["npu_block_type", "ofm_block", "ofm_depth_step"])
+    # Check the weight cache
+    if nng.weight_cache is None:
+        nng.weight_cache = CompressedWeightCache()
+    wcc = create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step)
+    tens.weight_compression_config = wcc
+    tens_cached = nng.weight_cache.get_tensor_with_same_compression(wcc)
+    if tens_cached is not None:
+        # Cache hit, copy weights from the cache
+        tens.copy_compressed_weight_info(tens_cached)
+        set_storage_shape(tens)
+        return
 
-    # check if weights have already been compressed
-    wcc = tens.weight_compression_config
-    if wcc is not None:
-        assert wcc.npu_block_type == npu_block_type, "Weights not used by the same operator type"
-
-        if wcc.ofm_block == ofm_block and wcc.ofm_depth_step == ofm_depth_step:
-            return
-
+    # No cache hit, perform the compression
     assert tens.quantization is not None
     assert tens.quantization.scale_f32 is not None
     assert tens.quantization.zero_point is not None
@@ -173,7 +215,6 @@
     compressed_offsets = []
     encoded_streams = []
     offset = 0
-    max_single_buffer_len = 0
 
     ifm_bitdepth = tens.consumer_list[0].inputs[0].dtype.size_in_bits()
     ifm_depth = weights.shape[-2]
@@ -200,14 +241,10 @@
         brick_weights = weights[:, :, :, idx : idx + count]
 
         # Encode all weights into one chunk
-        raw_stream = generate_brick(arch, brick_weights, ofm_block, tens.block_traversal, ifm_bitdepth)
+        raw_stream = generate_brick(arch, brick_weights, ofm_block_depth, tens.block_traversal, ifm_bitdepth)
         encoded = encode(raw_stream)
         encoded_streams.append(encoded)
 
-        # Remember maximum encoded length for DoubleBuffering
-        if max_single_buffer_len < len(encoded):
-            max_single_buffer_len = len(encoded)
-
         # Remember where we put it for linear addressing
         compressed_offsets.append(offset)
         offset += len(encoded)
@@ -219,18 +256,14 @@
     # Also track complete length in the offsets array
     compressed_offsets.append(offset)
 
-    if tens.sub_purpose == TensorSubPurpose.DoubleBuffer and len(encoded_streams) > 2:
-        offset = 2 * max_single_buffer_len
-        assert offset % 16 == 0
-
-    tens.storage_shape = [1, 1, 1, offset]
     tens.weight_compression_scales = compression_scales
-    tens.weight_compression_config = WeightCompressionConfig(npu_block_type, ofm_block, ofm_depth_step)
     tens.weight_compressed_offsets = compressed_offsets
     tens.compression_scale_for_worst_weight_stream = np.amax(compression_scales)
     tens.storage_compression_scale = tens.bandwidth_compression_scale = np.average(compression_scales)
     tens.compressed_values = encoded_streams
     tens.brick_size = (weights_shape[0], weights_shape[1], weights_shape[2], min(tens.shape[-1], ofm_depth_step))
+    set_storage_shape(tens)
+    nng.weight_cache.add(tens)
 
 
 def calc_scales_and_pack_biases(tens, arch, oc_quantum, rescale_for_faf=False):
@@ -352,39 +385,29 @@
 
     for sg in nng.subgraphs:
         for ps in sg.passes:
-            if ps.weight_tensor is not None:
-                npu_usage_of_tensor = find_npu_usage_of_tensor(ps.weight_tensor)
+            tens = ps.weight_tensor
+            if tens is not None:
+                npu_usage_of_tensor = find_npu_usage_of_tensor(tens)
                 if npu_usage_of_tensor == NpuBlockType.ConvolutionDepthWise:
-                    ps.weight_tensor.quant_values = np.transpose(ps.weight_tensor.quant_values, (0, 1, 3, 2))
-                    ps.weight_tensor.shape = ps.weight_tensor.storage_shape = ps.weight_tensor.bandwidth_shape = list(
-                        ps.weight_tensor.quant_values.shape
-                    )
-                    ps.weight_tensor.weight_transpose_depthwise = True
+                    tens.quant_values = np.transpose(tens.quant_values, (0, 1, 3, 2))
+                    tens.shape = tens.storage_shape = tens.bandwidth_shape = list(tens.quant_values.shape)
+                    tens.weight_transpose_depthwise = True
 
-                needs_dma = len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA"
+                needs_dma = tens.needs_dma()
                 if ps.cascade.strategy == SchedulingStrategy.WeightStream and needs_dma:
                     ofm_depth_step = ps.block_config[-1]
                 else:
-                    ofm_depth_step = ps.weight_tensor.shape[-1]
-
+                    ofm_depth_step = tens.shape[-1]
                 compress_weights(
-                    ps.weight_tensor,
-                    arch,
-                    npu_usage_of_tensor,
-                    Block(ps.block_config[-3], ps.block_config[-4], ps.block_config[-1]),
-                    ofm_depth_step,
+                    arch, nng, tens, npu_usage_of_tensor, ps.block_config[-1], ofm_depth_step,
                 )
                 # Update source tensor
-                if len(ps.weight_tensor.ops) == 1 and ps.weight_tensor.ops[0].type == "DMA":
-                    src_tens = ps.weight_tensor.ops[0].inputs[0]
-                    src_tens.shape = ps.weight_tensor.shape
-                    src_tens.weight_transpose_depthwise = ps.weight_tensor.weight_transpose_depthwise
-                    src_tens.quant_values = ps.weight_tensor.quant_values
-                    src_tens.compressed_values = ps.weight_tensor.compressed_values
-                    src_tens.storage_shape = [1, 1, 1, ps.weight_tensor.weight_compressed_offsets[-1]]
-                    src_tens.brick_size = ps.weight_tensor.brick_size
-                    src_tens.weight_compression_scales = ps.weight_tensor.weight_compression_scales
-                    src_tens.weight_compressed_offsets = ps.weight_tensor.weight_compressed_offsets
+                if needs_dma:
+                    src_tens = tens.get_dma_src_tensor()
+                    src_tens.shape = tens.shape
+                    src_tens.quant_values = tens.quant_values
+                    src_tens.copy_compressed_weight_info(tens)
+                    set_storage_shape(src_tens)
 
             if ps.scale_tensor is not None:
                 rescale_for_faf = False