MLBEDSW-3060 Adjust check if weights fit in sram

When deciding if weights fit sram:
A compression of the weights has been added when a
weight compression test limit makes it impossible to
fit weights in a double buffer in sram.

The worst compression ratio from compression, is used
to decide if weights can be fit in sram.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I9458769866b3f9fc15659185aae09658ed10fb38
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index e7c15cd..4f435dc 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -68,14 +68,14 @@
 memory_only_ops = set(("Reshape",))
 
 
-def remove_passthrough_tensor(tens, arch):
+def remove_passthrough_tensor(tens, arch, nng):
     if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
         assert len(tens.ops[0].inputs) == 1
         tens = tens.ops[0].inputs[0]
     return tens
 
 
-def rewrite_concat(tens, arch):
+def rewrite_concat(tens, arch, nng):
     if len(tens.ops) == 1 and tens.ops[0].is_concat_op():
         concat_op = tens.ops[0]
         if tens != concat_op.outputs[0]:
@@ -114,7 +114,7 @@
     return tens
 
 
-def rewrite_split(tens, arch):
+def rewrite_split(tens, arch, nng):
 
     if len(tens.ops) == 1 and tens.ops[0].is_split_op():
         split_op = tens.ops[0]
@@ -205,7 +205,7 @@
     return padding, skirt
 
 
-def fixup_conv2d_backprop(op, arch):
+def fixup_conv2d_backprop(op, arch, nng):
     if op.type == "Conv2DBackpropInput":
         # flip the inputs
         op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
@@ -295,7 +295,7 @@
     return op
 
 
-def fixup_resizebilinear(op, arch):
+def fixup_resizebilinear(op, arch, nng):
     if op.type == "ResizeBilinear" and op.run_on_npu:
         if op.inputs[0].shape == op.outputs[0].shape:
             # Bypass nop resizebilinear
@@ -309,7 +309,7 @@
     return op
 
 
-def convert_nop_split_to_identity(op, arch):
+def convert_nop_split_to_identity(op, arch, nng):
     if op.type == "Split" and op.attrs.get("num_splits") == 1:
         # the list comprehension should return a list with a single tensor
         # if it shouldn't, remove_passthrough_tensor will fail appropriately
@@ -318,7 +318,7 @@
     return op
 
 
-def fixup_fully_connected_input(op, arch):
+def fixup_fully_connected_input(op, arch, nng):
     if op.type == "FullyConnectedAct":
         inp = op.inputs[0]
         weights = op.inputs[1]
@@ -336,7 +336,7 @@
     return op
 
 
-def convert_batched_fc_to_conv(op, arch):
+def convert_batched_fc_to_conv(op, arch, nng):
     if op.type == "FullyConnectedAct":
         ifm = op.inputs[0]
         ofm = op.outputs[0]
@@ -407,7 +407,7 @@
     return op
 
 
-def fixup_pack_input(op, arch):
+def fixup_pack_input(op, arch, nng):
     if op.type == "Pack":
         # Pack is also referred to as Stack
         # Requires the rewrite_concat function to be called on the op afterwards
@@ -433,7 +433,7 @@
     return op
 
 
-def unfuse_activation_function(op, arch):
+def unfuse_activation_function(op, arch, nng):
     unfuse_ops = ("ConcatTFLite",)
     if op.type in unfuse_ops and op.run_on_npu and op.attrs.get("fused_activation_function", None) is not None:
         act = op.attrs["fused_activation_function"]
@@ -448,7 +448,7 @@
     return op
 
 
-def fixup_unpack_output(tens, arch):
+def fixup_unpack_output(tens, arch, nng):
     op = tens.ops[0]
     if op.type in set(("Unpack", "StridedSlice")):
         # Unpack is also referred to as Unstack
@@ -515,7 +515,7 @@
     return tens
 
 
-def add_padding_fields(op, arch):
+def add_padding_fields(op, arch, nng):
     if op.run_on_npu:
         if "padding" in op.attrs:
             if op.type in conv_op | depthwise_op:
@@ -564,7 +564,7 @@
     return None
 
 
-def mark_npu_block_type(op, arch):
+def mark_npu_block_type(op, arch, nng):
     npu_block_type = NpuBlockType.Default
     if op.type in conv_op:
         npu_block_type = NpuBlockType.ConvolutionMxN
@@ -583,7 +583,7 @@
     return op
 
 
-def convert_depthwise_to_conv(op, arch):
+def convert_depthwise_to_conv(op, arch, nng):
     # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
     # the ofm depth equals the depth multipler.
     # If those conditions are true, then we can perform a simple
@@ -610,7 +610,7 @@
     return op
 
 
-def reorder_depthwise_weights(op, arch):
+def reorder_depthwise_weights(op, arch, nng):
     if op.type in depthwise_op:
         weight_tensor = op.inputs[1]
         weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
@@ -620,7 +620,7 @@
     return op
 
 
-def convert_conv_to_fc(op, arch):
+def convert_conv_to_fc(op, arch, nng):
     # Conv 1x1 can be equivalent to Fully Connected.
     # By representing certain convs as fully connected layers, Vela can better determine wether or not to use
     # caching/double buffering for the weights.
@@ -661,7 +661,7 @@
     return op
 
 
-def fixup_relus_with_differing_ifm_ofm_scaling(op, arch):
+def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
     if op.run_on_npu and op.type in relu_ops:
         ifm = op.inputs[0]
         ofm = op.outputs[0]
@@ -690,7 +690,7 @@
 
 
 # Reorder activation op if it's after the memory only operations
-def fixup_act_reorder(op, arch):
+def fixup_act_reorder(op, arch, nng):
     if op.type in activation_ops:
         prep_op = get_prepend_op(op)
         if prep_op is not None:
@@ -715,7 +715,7 @@
     return op
 
 
-def fixup_elementwise_with_scalars(op, arch):
+def fixup_elementwise_with_scalars(op, arch, nng):
     if op.type in binary_elementwise_op:
         ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
         if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
@@ -736,7 +736,7 @@
 
 
 # Set input/output tensor equivalence to the same id for memory operations
-def set_tensor_equivalence(op, arch):
+def set_tensor_equivalence(op, arch, nng):
     if op.type in memory_only_ops:
         eid = op.outputs[0].equivalence_id
         for inp in op.inputs:
@@ -744,14 +744,14 @@
     return op
 
 
-def convert_softmax(op, arch):
+def convert_softmax(op, arch, nng):
     if op.type == "Softmax" and op.run_on_npu:
         softmax = SoftMax(op)
         op = softmax.get_graph()
     return op
 
 
-def convert_mul_max_to_abs_or_lrelu(op, arch):
+def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
     r"""Whenever there is a subgraph with this topology:
 
        Input    X   For X = -1 or X > 0
@@ -958,7 +958,7 @@
     return convert_to_lut(op, values)
 
 
-def convert_lrelu(op, arch):
+def convert_lrelu(op, arch, nng):
     # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
     if op.type != "LeakyRelu":
         return op
@@ -972,7 +972,7 @@
     return convert_lrelu_to_mul_max(op, arch)
 
 
-def convert_tanh_sigmoid_to_lut(op, arch):
+def convert_tanh_sigmoid_to_lut(op, arch, nng):
     # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
     if op.type == "Sigmoid":
         return convert_to_lut8(op, clamp_sigmoid)
@@ -981,7 +981,7 @@
     return op
 
 
-def remove_unwanted_reshapes(op, arch):
+def remove_unwanted_reshapes(op, arch, nng):
     # Try to remove reshapes enclosing ElementWise operator with only one non-constant input
     if not op.run_on_npu or op.attrs["npu_block_type"] != NpuBlockType.ElementWise:
         return op
@@ -1016,7 +1016,7 @@
     return op
 
 
-def fuse_activation_function_with_prev(op, arch):
+def fuse_activation_function_with_prev(op, arch, nng):
     # if op is a no-op: attempts to move the activation function to the preceding op
     if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None:
         return op
@@ -1049,7 +1049,7 @@
     return op
 
 
-def add_attrs_to_resizebilinear(op, arch):
+def add_attrs_to_resizebilinear(op, arch, nng):
     if op.type == "ResizeBilinear" and op.run_on_npu:
         input_tensor = op.inputs[0]
         upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2]
@@ -1069,7 +1069,7 @@
     return op
 
 
-def fixup_bias_tensors(op, arch):
+def fixup_bias_tensors(op, arch, nng):
     if op.needs_bias() and not op.inputs[-1]:
         # Op has no bias, add bias tensor filled with zeros
         nr_biases = op.inputs[1].shape[-1]
@@ -1081,7 +1081,7 @@
     return op
 
 
-def supported_operator_check(op, arch):
+def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.supported_operators.is_operator_supported(op)
     return op
 
@@ -1121,13 +1121,13 @@
     for idx, sg in enumerate(nng.subgraphs):
         # rewrite graph pass
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            sg, arch, [fixup_unpack_output], op_rewrite_list, rewrite_unsupported=False
+            nng, sg, arch, [fixup_unpack_output], op_rewrite_list, rewrite_unsupported=False
         )
 
     for idx, sg in enumerate(nng.subgraphs):
         # remove passthrough tensors and attempt further optimizations
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
+            nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
         )
 
     if verbose_graph:
@@ -1141,7 +1141,7 @@
 
     for idx, sg in enumerate(nng.subgraphs):
         # combined rewrite graph pass
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(sg, arch, [rewrite_concat, rewrite_split], [])
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [rewrite_concat, rewrite_split], [])
 
     if verbose_graph:
         nng.print_graph()
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
index 9304526..99b46c0 100644
--- a/ethosu/vela/insert_dma.py
+++ b/ethosu/vela/insert_dma.py
@@ -21,12 +21,13 @@
 from .tensor import MemArea
 from .tensor import MemType
 from .tensor import TensorPurpose
+from .weight_compressor import compress_weights
 
 
 binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum"))
 
 
-def weights_fit_sram(arch, tens):
+def weights_fit_sram(arch, op, tens, nng):
     if tens.purpose != TensorPurpose.Weights:
         return True
 
@@ -36,25 +37,33 @@
     elif len(tens.shape) == 2:
         min_weight_size = tens.shape[0] * arch.OFMSplitDepth
 
-    w_compression = 1  # TODO worst compression ratio currently assumed
-
     # Need to be fit into Sram, as a double buffer
-    if (w_compression * min_weight_size * 2) > arch.sram_size:
-        print(
-            "Weights, {}, are too big to be DMAed to SRAM, estimated minimum size is {} bytes".format(
-                tens.name, (w_compression * min_weight_size * 2)
+    # Only evaluate when the compression test limit will make it impossible to fit
+    w_comp_test_limit = 2
+    if (w_comp_test_limit * min_weight_size * 2) > arch.sram_size:
+        # check worst compression ratio
+        npu_block_type = op.attrs.get("npu_block_type", NpuBlockType.Default)
+        compress_weights(arch, nng, tens, npu_block_type, 16, 16, op.get_dilation_h_w())
+
+        worst_buffer_size = tens.compression_scale_for_worst_weight_stream * min_weight_size * 2
+        if worst_buffer_size > arch.sram_size:
+            print(
+                "Weights, {}, are too big to be DMAed to SRAM, estimated minimum size is {} bytes".format(
+                    tens.name, worst_buffer_size
+                )
             )
-        )
-        return False
+            return False
     return True
 
 
-def insert_dma_cmd(op, arch):
+def insert_dma_cmd(op, arch, nng):
     if op.type == "DMA" or not op.run_on_npu:
         return op
 
-    is_lut_used         = any(inp.purpose == TensorPurpose.LUT for inp in op.inputs)
-    max_ifm_shram_avail = (arch.available_shram_banks(is_lut_used) - arch.shram_reserved_output_banks) * arch.shram_bank_size // 2
+    is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in op.inputs)
+    max_ifm_shram_avail = (
+        (arch.available_shram_banks(is_lut_used) - arch.shram_reserved_output_banks) * arch.shram_bank_size // 2
+    )
 
     for idx, tens in enumerate(op.inputs):
 
@@ -66,8 +75,11 @@
                 and arch.permanent_storage_mem_area != arch.fast_storage_mem_area
             ) or tens.purpose == TensorPurpose.LUT:
                 if tens.purpose in (TensorPurpose.Weights, TensorPurpose.LUT) or (
-                    tens.purpose == TensorPurpose.FeatureMap and op.type in binary_elementwise_op and
-                    tens.shape != [] and tens.shape != op.outputs[0].shape and tens.storage_size() > max_ifm_shram_avail
+                    tens.purpose == TensorPurpose.FeatureMap
+                    and op.type in binary_elementwise_op
+                    and tens.shape != []
+                    and tens.shape != op.outputs[0].shape
+                    and tens.storage_size() > max_ifm_shram_avail
                 ):
                     only_vector_product_consumers = True
                     for oper in tens.consumers():
@@ -79,7 +91,7 @@
                     # Other operations re-reads tensors, this is better done from SRAM.
                     # LUTs must be placed in the last 2 blocks of SHRAM.
                     if (
-                        not only_vector_product_consumers and weights_fit_sram(arch, tens)
+                        not only_vector_product_consumers and weights_fit_sram(arch, op, tens, nng)
                     ) or tens.purpose == TensorPurpose.LUT:
                         # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size.
                         new_tens = tens.clone_into_fast_storage(arch)
@@ -98,7 +110,7 @@
 def insert_dma_commands(nng, arch, verbose_graph=False):
 
     for idx, sg in enumerate(nng.subgraphs):
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [insert_dma_cmd])
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [insert_dma_cmd])
     if verbose_graph:
         nng.print_graph()
     return nng
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index a971ef2..c4496cd 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -266,7 +266,7 @@
             )  # special case constants, as they must be in permanent storage
             tens.mem_type = MemType.Permanent_NPU
 
-    def rewrite_mark_tensor_purpose(op, arch):
+    def rewrite_mark_tensor_purpose(op, arch, nng):
         # find disconnected outputs and mark as parameters
         for tens in op.outputs:
             if not tens.consumers():
@@ -308,7 +308,7 @@
         return op
 
     for sg in nng.subgraphs:
-        sg = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [rewrite_mark_tensor_purpose])
+        sg = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [rewrite_mark_tensor_purpose])
         for tens in sg.output_tensors:
             mark_tensor_helper(tens, TensorPurpose.FeatureMap)
 
diff --git a/ethosu/vela/rewrite_graph.py b/ethosu/vela/rewrite_graph.py
index e76e961..e71b228 100644
--- a/ethosu/vela/rewrite_graph.py
+++ b/ethosu/vela/rewrite_graph.py
@@ -24,7 +24,7 @@
 # Post-order traversal, this does not support rewrites. Therefore, functions must return the original value.
 
 
-def rewrite_graph_pre_order(sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
+def rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
 
     op_visit_dict = dict()
     tens_visit_dict = dict()
@@ -38,7 +38,7 @@
             prev_res = res
             for rewrite in op_rewrite_list:
                 if res.run_on_npu or rewrite_unsupported:
-                    res = rewrite(res, arch)
+                    res = rewrite(res, arch, nng)
 
         op_visit_dict[op] = res
         op_visit_dict[res] = res
@@ -64,7 +64,7 @@
         while prev_res != res:
             prev_res = res
             for rewrite in tensor_rewrite_list:
-                res = rewrite(res, arch)
+                res = rewrite(res, arch, nng)
 
         tens_visit_dict[tens] = res
         tens_visit_dict[res] = res
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index c5a3f3f..8426705 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -280,7 +280,6 @@
 # Compress the weights
 def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
     assert tens.purpose == TensorPurpose.Weights
-    assert tens.format == TensorFormat.WeightsCompressed
 
     # Check the weight cache
     if nng.weight_cache is None: