MLBEDSW-4602: Fix Deepspeech scale & bias reuse issue.

 - Deepspeech reuses identical weights and biases throughout
   the network. Since biases are now interleaved with weights,
   there is a scaling issue when the ifm scales differ between
   operations that use the same weight and scale tensor.

 - This commit uses interleaved weights/scales on their first use,
   but moves the scales into a separate tensor in source memory on
   subsequent use if the ifm scale differs, as sketched below.
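 - Minimal sketch of the reuse decision (simplified, hypothetical
   names; not the actual vela structures): the interleaved
   weight/scale stream is cached keyed on the weight config alone,
   and a later op with the same weights but a different ifm scale
   reuses the cached weights while its scales are re-encoded into a
   separate tensor.

      from collections import namedtuple

      # Hypothetical stand-ins for WeightCompressionConfig / ScaleCompressionConfig
      WeightCfg = namedtuple("WeightCfg", ["weight_value_id", "block_depth"])
      ScaleCfg = namedtuple("ScaleCfg", ["scale_value_id", "ifm_scale", "ofm_scale"])

      cache = {}  # WeightCfg -> previously encoded tensor (stands in for CompressedWeightCache)

      def encode(op):
          wcc = WeightCfg(op["weight_id"], op["block_depth"])
          scc = ScaleCfg(op["scale_id"], op["ifm_scale"], op["ofm_scale"])
          cached = cache.get(wcc)
          if cached is None:
              tens = {"stream": "weights + interleaved scales", "scc": scc}
              cache[wcc] = tens
              return tens, None            # first use: single interleaved tensor
          if cached["scc"] == scc:
              return cached, None          # identical scaling: full reuse
          scales = {"stream": "scales only", "scc": scc}
          return cached, scales            # weights reused, scales re-encoded separately

      op_a = {"weight_id": 7, "block_depth": 32, "scale_id": 9,
              "ifm_scale": 0.05, "ofm_scale": 0.1}
      op_b = dict(op_a, ifm_scale=0.02)    # same weights/biases, different ifm scale
      print(encode(op_a))                  # (interleaved tensor, None)
      print(encode(op_b))                  # (cached weights, separate scale-only tensor)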

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I7aae163438160a919cae04e235966e75355a6148
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 652d016..4ce03d5 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -40,10 +40,11 @@
 # Contains meta info for a weight compression. If two tensors have identical weight compression config,
 # then they also will have identical compressed weights.
 WeightCompressionConfig = namedtuple(
-    "WeightCompressionConfig",
-    ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id", "scale_value_id"],
+    "WeightCompressionConfig", ["npu_block_type", "ofm_block_depth", "ofm_depth_step", "dilation", "weight_value_id"],
 )
 
+ScaleCompressionConfig = namedtuple("ScaleCompressionConfig", ["scale_value_id", "ifm_scale", "ofm_scale"])
+
 WeightKey = namedtuple("WeightKey", ["core", "depth"])
 
 
@@ -68,6 +69,7 @@
         self.encoded_ranges = OrderedDict()
         self.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
         self.dtype = DataType.uint8
+        self.scale_compression_config = None
 
 
 class CompressedWeightCache:
@@ -95,15 +97,11 @@
         return cache_obj[1] if cache_obj else None
 
 
-def create_weight_compression_config(
-    weight_tens, scale_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation
-):
+def create_weight_compression_config(weight_tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
     # Note: for an ofm block only its depth is used in weight compression.
     # And block depth > ofm depth gives same result as block depth == ofm depth
     block_depth = min(ofm_block_depth, weight_tens.quant_values.shape[-1])
-    return WeightCompressionConfig(
-        npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id, scale_tens.value_id
-    )
+    return WeightCompressionConfig(npu_block_type, block_depth, ofm_depth_step, dilation, weight_tens.value_id)
 
 
 def encode_weights(
@@ -277,72 +275,86 @@
 
 def encode_weight_and_scale_tensor(
     arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets, rescale_for_faf=False
-) -> NpuWeightTensor:
+) -> (NpuWeightTensor, NpuWeightTensor):
     npu_block_type = op.type.npu_block_type
 
+    ifm_scale = scale_tens and scale_tens.consumer_list[0].get_input_quantization().scale_f32
+    ofm_scale = scale_tens and scale_tens.consumer_list[0].get_output_quantization().scale_f32
+
     wcc = create_weight_compression_config(
-        weight_tens, scale_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
+        weight_tens, npu_block_type, block_config.ofm_block.depth, hash(str(depth_offsets)), kernel.dilation
     )
 
+    scc = ScaleCompressionConfig(scale_tens and scale_tens.value_id, ifm_scale, ofm_scale)
+
     tens_cached = CompressedWeightCache.get_tensor_with_same_compression(wcc)
     if tens_cached is not None:
-        return tens_cached
+        if tens_cached.scale_compression_config == scc:
+            return tens_cached, None
+        npu_tensor = NpuWeightTensor(scale_tens.name)
+        do_weights = False
+        do_scales = True
+    else:
+        npu_tensor = NpuWeightTensor(weight_tens.name)
+        do_weights = True
+        do_scales = True
 
-    npu_tensor = NpuWeightTensor(weight_tens.name)
     npu_tensor.weight_compression_config = wcc
+    npu_tensor.scale_compression_config = scc
 
-    # No cache hit, perform the compression
-    assert weight_tens.quantization is not None
-    assert weight_tens.quantization.scale_f32 is not None
-    assert weight_tens.quantization.zero_point is not None
-
-    zero_point = weight_tens.quantization.zero_point
-    quant_buf = weight_tens.quant_values.astype(np.int64)
-
-    # Early zero-point correction
-    weights = quant_buf - zero_point
-
-    if len(weights.shape) == 2:
-        weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
-
-    # Expect this (undilated) equivalence
-    assert kernel.height == weights.shape[0]
-    assert kernel.width == weights.shape[1]
     # Ensure depth offsets are terminated at end of OFM shape
     assert len(depth_offsets) > 1, "Require closed depth ranges"
 
     ifm_bitdepth = op.inputs[0].dtype.size_in_bits()
-    ifm_depth = weights.shape[-2]
 
-    # Default HW traversal
-    npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
+    # Perform the encoding; the weights are only prepared and encoded on a cache miss
+    if do_weights:
+        assert weight_tens.quantization is not None
+        assert weight_tens.quantization.scale_f32 is not None
+        assert weight_tens.quantization.zero_point is not None
 
-    if npu_block_type == NpuBlockType.ConvolutionMxN:
-        # Determine which block traversal strategy has better DPU utilization
-        kernel_size = weights.shape[0] * weights.shape[1]
-        depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
-        part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
-            kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
-        )
-        if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
-            # Part-kernel first is always better for ifm depths <= 8
-            npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+        # Early zero-point correction
+        quant_buf = weight_tens.quant_values.astype(np.int64)
+        weights = quant_buf - weight_tens.quantization.zero_point
 
-    if op.type == Op.Conv2DBackpropInputSwitchedBias:
-        # Transpose Convoluion, reverse weights in H and W axes
-        weights = np.flip(weights, axis=(0, 1))
+        if len(weights.shape) == 2:
+            weights = np.expand_dims(np.expand_dims(weights, axis=0), axis=0)
+
+        # Expect this (undilated) equivalence
+        assert kernel.height == weights.shape[0]
+        assert kernel.width == weights.shape[1]
+
+        ifm_depth = weights.shape[-2]
+
+        # Default HW traversal
+        npu_tensor.hw_traversal = NpuBlockTraversal.DEPTH_FIRST
+
+        if npu_block_type == NpuBlockType.ConvolutionMxN:
+            # Determine which block traversal strategy has better DPU utilization
+            kernel_size = weights.shape[0] * weights.shape[1]
+            depth_utilization = weights.shape[2] / round_up(weights.shape[2], 32 if ifm_bitdepth == 8 else 16)
+            part_kernel_utilization = (weights.shape[2] / round_up(weights.shape[2], 8)) * (
+                kernel_size / round_up(kernel_size, 4 if ifm_bitdepth == 8 else 2)
+            )
+            if part_kernel_utilization >= depth_utilization or ifm_depth <= 8:
+                # Part-kernel first is always better for ifm depths <= 8
+                npu_tensor.hw_traversal = NpuBlockTraversal.PART_KERNEL_FIRST
+
+        if op.type == Op.Conv2DBackpropInputSwitchedBias:
+            # Transpose Convolution, reverse weights in H and W axes
+            weights = np.flip(weights, axis=(0, 1))
 
     encoded_stream = bytearray()
     max_single_buffer_len = 0
     is_depthwise = npu_block_type == NpuBlockType.ConvolutionDepthWise
 
     # Bias & scale
-    if scale_tens:
+    if do_scales:
         quantised_scales, biases = _prepare_scale_and_bias(arch, scale_tens, rescale_for_faf)
         scale_tens.element_size_bytes = 10
 
     # Slice the weight stream up depth-ways into bricks and compress
-    full_ofm_depth = quant_buf.shape[-1]
+    full_ofm_depth = weight_tens.quant_values.shape[-1]
     ofm_block_depth = block_config.ofm_block.depth
 
     weight_range_index = 0
@@ -352,11 +364,12 @@
         depth_length = depth_offsets[idx + 1] - depth_offset
 
         # Get the weights necessary for this brick
-        brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]
+        if do_weights:
+            brick_weights = weights[:, :, :, depth_offset : depth_offset + depth_length]
 
         buffer_start_offset = len(encoded_stream)
 
-        # For each core, deinterleave weights from the larger volume
+        # For each core, deinterleave weights/scales from the larger volume
         # and generate separate compressed streams.
         for core in range(0, min(arch.ncores, full_ofm_depth)):
 
@@ -370,7 +383,7 @@
                 weight_range_index += 1
 
                 # Scales & biases
-                if scale_tens:
+                if do_scales:
                     scale_stream = []
                     core_scales = quantised_scales[
                         depth_offset + core : depth_offset + core + depth_length : arch.ncores
@@ -389,36 +402,49 @@
                         encoded_stream.extend(bytearray(16 - remainder))
 
                 # Weights
-                core_weights = core_deinterleave(brick_weights, core, arch.ncores)
-                encoded_substream, _ = encode_weights(
-                    accelerator=arch.accelerator_config,
-                    weights_volume=core_weights,
-                    dilation_xy=kernel.dilation,
-                    ifm_bitdepth=ifm_bitdepth,
-                    ofm_block_depth=core_block_depth,
-                    is_depthwise=is_depthwise,
-                    block_traversal=npu_tensor.hw_traversal,
-                )
+                if do_weights:
+                    core_weights = core_deinterleave(brick_weights, core, arch.ncores)
+                    encoded_substream, _ = encode_weights(
+                        accelerator=arch.accelerator_config,
+                        weights_volume=core_weights,
+                        dilation_xy=kernel.dilation,
+                        ifm_bitdepth=ifm_bitdepth,
+                        ofm_block_depth=core_block_depth,
+                        is_depthwise=is_depthwise,
+                        block_traversal=npu_tensor.hw_traversal,
+                    )
+                    weight_range.weight_offset = len(encoded_stream) - weight_range.offset
+                    weight_range.weight_bytes = len(encoded_substream)
+                    # Append encoded section
+                    encoded_stream.extend(encoded_substream)
+                    assert len(encoded_stream) % 16 == 0
 
-                weight_range.weight_offset = len(encoded_stream) - weight_range.offset
-                weight_range.weight_bytes = len(encoded_substream)
-
-                # Append encoded weights section
-                encoded_stream.extend(encoded_substream)
-                assert len(encoded_stream) % 16 == 0
-
-                # Record encoded range in weights tensor
+                # Record encoded range in tensor
                 npu_tensor.encoded_ranges[key] = weight_range
 
         # Remember maximum encoded length for DoubleBuffering
         max_single_buffer_len = max(max_single_buffer_len, len(encoded_stream) - buffer_start_offset)
 
+    # Attach buffer to tensor
     npu_tensor.buffer = encoded_stream
     npu_tensor.max_range_bytes = max_single_buffer_len
     npu_tensor.set_all_shapes([1, 1, 1, len(encoded_stream)])
     npu_tensor.format = TensorFormat.WeightsCompressed
-    npu_tensor.purpose = TensorPurpose.Weights
-    npu_tensor.mem_area = weight_tens.mem_area
-    npu_tensor.mem_type = weight_tens.mem_type
-    CompressedWeightCache.add(npu_tensor)
-    return npu_tensor
+
+    # Scale-only tensor
+    if not do_weights:
+        npu_tensor.weight_compression_config = None
+        npu_tensor.purpose = TensorPurpose.FSBias
+        npu_tensor.mem_area = scale_tens.mem_area
+        npu_tensor.mem_type = scale_tens.mem_type
+        weights_tensor = tens_cached
+        scale_tensor = npu_tensor
+    else:
+        npu_tensor.purpose = TensorPurpose.Weights
+        npu_tensor.mem_area = weight_tens.mem_area
+        npu_tensor.mem_type = weight_tens.mem_type
+        weights_tensor = npu_tensor
+        scale_tensor = None
+        CompressedWeightCache.add(weights_tensor)
+
+    return weights_tensor, scale_tensor
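
A possible caller-side sketch of consuming the new two-tensor return
(the attribute names below are illustrative only; the real call sites
in the scheduler may differ):

    weights_tensor, scale_tensor = encode_weight_and_scale_tensor(
        arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets
    )
    op.weights_compressed = weights_tensor      # always returned
    if scale_tensor is not None:
        # The scales could not be reused from the cached interleaved stream,
        # so they live in a separate FSBias tensor placed in source memory.
        op.scales_compressed = scale_tensor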