MLBEDSW-2809: Redo the Tensor addressing

Added a static class TensorAddressMap that stores all Tensor addresses,
keyed by their equivalence_id and memory type. Made the "address" field
into a property whose getter and setter look up/set the tensor's address
in TensorAddressMap.

This makes the references to cpu_tensor/npu_tensor obsolete and they
have been removed.

Scheduler change: avoid SRAM spilling if an op has consumers in other
subgraphs.

Minor rework in LUTState; a unique equivalence_id is now assigned to the
SHRAM LUT tensor to avoid issues with addressing. The equivalence checks
in LUTState now compare the values of the LUT instead of the
equivalence_id.

Updated LUT unit tests accordingly.

Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com>
Change-Id: I41de5a8a4e5f07b77d6544d8d4034b754993e503
diff --git a/ethosu/vela/extract_npu_subgraphs.py b/ethosu/vela/extract_npu_subgraphs.py
index 4adddc1..c0430b5 100644
--- a/ethosu/vela/extract_npu_subgraphs.py
+++ b/ethosu/vela/extract_npu_subgraphs.py
@@ -70,10 +70,7 @@
     orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
 ):
     is_const = orig_tens.ops[0].type == "Const"
-
     new_tens = orig_tens.clone("_npu")
-    orig_tens.npu_tensor = new_tens
-    new_tens.cpu_tensor = orig_tens
 
     op_type = "SubgraphInput"
     if is_const:
@@ -107,9 +104,6 @@
 ):
 
     new_tens = orig_tens.clone("_cpu")
-    new_tens.npu_tensor = orig_tens
-    orig_tens.cpu_tensor = new_tens
-
     npu_subgraph.output_tensors.append(orig_tens)
 
     call_ps.outputs.append(new_tens)
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 156090f..9a8ee58 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -84,21 +84,11 @@
         return self.name < other.name
 
     def set_address(self, address):
-        # Set address of all unaddressed tensors in LiveRange
+        # Set address of all tensors in LiveRange
         for tens in self.tensors:
-            if tens.address is None:
-                addr = address
-            else:
-                # Limit to single tensor for the lr if the tensor address already assigned
-                assert len(self.tensors) == 1
-                addr = tens.address
-            tens.address = addr
-            # Also need to set the address to the tensor's cpu/npu clones
-            if tens.cpu_tensor is not None:
-                tens.cpu_tensor.address = addr
-            if tens.npu_tensor is not None:
-                tens.npu_tensor.address = addr
-        return addr
+            tens.address = address
+
+        return address
 
     def get_alignment(self):
         return self.alignment
@@ -113,10 +103,6 @@
             # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
             input_tensor = ps.inputs[0]
             output_tensor = ps.outputs[0]
-            # If the input or output tensor is tied to a Cpu tensor, i.e. a subgraph input
-            # or output, fuse the live-range with the Cpu tensors' live-range instead.
-            input_tensor = input_tensor.cpu_tensor if input_tensor.cpu_tensor is not None else input_tensor
-            output_tensor = output_tensor.cpu_tensor if output_tensor.cpu_tensor is not None else output_tensor
             if not tensor_should_be_ignored(input_tensor, target_mem_area) and not tensor_should_be_ignored(
                 output_tensor, target_mem_area
             ):
@@ -132,9 +118,9 @@
         self.current_time = 0
 
     def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
-        for rng in self.ranges.values():
-            # Return the live range of the tensor (or it's cpu/npu clone)
-            if any(tensor in rng.tensors for tensor in [tens, tens.npu_tensor, tens.cpu_tensor]):
+        # Return the live range of the tensor (or any of its clones)
+        for existing_tensor, rng in self.ranges.items():
+            if tens.equivalent(existing_tensor):
                 rng.set_alignment(alignment)
                 return rng
 
@@ -252,10 +238,6 @@
                 # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
                 input_tensor = ps.inputs[0]
                 output_tensor = ps.outputs[0]
-                # If the input or output tensor is tied to a Cpu tensor, i.e. a subgraph input
-                # or output, fuse the live-range with the Cpu tensors' live-range instead.
-                input_tensor = input_tensor.cpu_tensor if input_tensor.cpu_tensor is not None else input_tensor
-                output_tensor = output_tensor.cpu_tensor if output_tensor.cpu_tensor is not None else output_tensor
                 if not tensor_should_be_ignored(input_tensor, target_mem_area, target_mem_type_set) and not (
                     tensor_should_be_ignored(output_tensor, target_mem_area, target_mem_type_set)
                 ):
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index 0e8dcc9..e3373ca 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -42,9 +42,9 @@
         self.tensors = []
 
     def get_equivalent(self, lut_tens):
-        # Returns existing lut with same equivalence id, None if not found
+        # Returns existing lut with the same values, None if not found
         for t in self.tensors:
-            if t.equivalent(lut_tens):
+            if np.array_equal(t.values, lut_tens.values):
                 return t
         return None
 
@@ -60,6 +60,7 @@
             end2 = start2 + tens.storage_size()
             if not numeric_util.overlaps(start, end, start2, end2):
                 new_state.tensors.append(tens)
+
         return new_state
 
     def find_best_address(self, start, stop, step):
@@ -129,6 +130,7 @@
         # Place the LUT in the last 2 blocks of SHRAM
         # Alignment is always on the size of the LUT, 256 for 256-byte LUT, 1K for 1K LUT, etc
         address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size())
+        lut_tens.equivalence_id = uuid.uuid4()
         lut_tens.address = address
         cmd.ps.primary_op.attrs["lut_index"] = (address - lut_start) // slot_size
         lut_state = lut_state.put(lut_tens)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index e9a93c1..47f8a47 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -35,6 +35,7 @@
 from .npu_performance import make_macs_array
 from .npu_performance import make_metrics_arrays
 from .npu_performance import PassCycles
+from .numeric_util import full_shape
 from .operation import NpuBlockType
 from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
 from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
@@ -43,7 +44,7 @@
 from .tensor import TensorFormat
 from .tensor import TensorPurpose
 from .tensor import TensorSubPurpose
-from .numeric_util import full_shape
+
 
 class ParetoMetric(enum.Enum):
     BwCycMem = 1
@@ -652,6 +653,9 @@
         for op in pred_candidate.ops:
             if op.type == "ConcatSliceWrite":
                 return True
+            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
+                # The op has consumers in other subgraphs
+                return True
         return False
 
     def search_ifm_streaming_partial(self, ps, block_config):
@@ -976,8 +980,15 @@
                                 # be processed by CPU operations. No-op reshape consumers with empty lists
                                 # (those that have no consumers, or null-consumers used as list terminators)
                                 # must use normal NHWC output.
-                                incompatible_consumers = [ (not consumer.run_on_npu or consumer.type == "Reshape" or (consumer is last_op_in_subgraph))
-                                                           for consumer in op.outputs[0].consumer_list if consumer is not None ]
+                                incompatible_consumers = [
+                                    (
+                                        not consumer.run_on_npu
+                                        or consumer.type == "Reshape"
+                                        or (consumer is last_op_in_subgraph)
+                                    )
+                                    for consumer in op.outputs[0].consumer_list
+                                    if consumer is not None
+                                ]
                                 if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
                                     rewrites.append(op)
                                 else:
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 49521e7..0f8170d 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -17,6 +17,7 @@
 # Internal representation of a Neural Network Tensor.
 import enum
 import uuid
+from collections import defaultdict
 
 import numpy as np
 
@@ -258,6 +259,25 @@
     return reshape_ofm if ifm_reshape else reshape_ifm
 
 
+# Class-level registry of tensor addresses, keyed by equivalence_id and then by memory type
+class TensorAddressMap:
+    address_map = defaultdict(dict)  # equivalence_id -> {mem_type: address}; class attribute, not reset in this code
+
+    @classmethod
+    def get_address_for_tens(cls, tens_id, mem_type):
+        return cls.address_map[tens_id].get(mem_type)  # None if unset; indexing the defaultdict inserts an empty entry
+
+    @classmethod
+    def set_address_for_tens(cls, tens_id, mem_type, address):
+        # Re-assigning the same address is allowed; assigning a different one for this (id, mem_type) is an error
+        previous_address = cls.address_map[tens_id].get(mem_type)
+        if previous_address is not None:
+            assert previous_address == address, "Two different addresses cannot be assigned to the same tensor."  # NOTE(review): stripped under python -O
+
+        # Record the address for this equivalence_id in the given memory type
+        cls.address_map[tens_id][mem_type] = address
+
+
 class Tensor:
     __slots__ = (
         "shape",
@@ -285,13 +305,10 @@
         "weight_compression_config",
         "storage_rounding_quantum",
         "brick_size",
-        "address",
         "quantization",
         "weight_compressed_offsets",
         "element_size_bytes",
         "block_traversal",
-        "cpu_tensor",
-        "npu_tensor",
         "equivalence_id",
         "resampling_mode",
         "avoid_NHCWB16",
@@ -308,10 +325,6 @@
 
         self.ops = []
         self.consumer_list = []
-        # Below attributes are only set if a tensor has been cloned,
-        # either from Cpu -> Npu or vice versa. Needed for offline allocation
-        self.cpu_tensor = None  # reference to the corresponding Cpu tensor
-        self.npu_tensor = None  # reference to the corresponding Npu tensor
 
         self.values = None
         self.quant_values = None
@@ -333,7 +346,6 @@
         self.weight_compressed_offsets = []
         self.storage_rounding_quantum = (1, 1, 1, 1)
         self.brick_size = (1, 1, 1, 1)
-        self.address = None  # start address of tensor. will be filled in by tensor allocator
         self.element_size_bytes = 0
 
         # quantization parameters
@@ -343,6 +355,14 @@
 
         self.avoid_NHCWB16 = False
 
+    @property
+    def address(self):  # start address, looked up by (equivalence_id, mem_type) in TensorAddressMap
+        return TensorAddressMap.get_address_for_tens(self.equivalence_id, self.mem_type)  # None until an address is set
+
+    @address.setter
+    def address(self, address):  # applies to every tensor sharing this equivalence_id and mem_type
+        TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)  # asserts if a different address was set
+
     def element_size(self):
         if self.element_size_bytes == 0:
             return self.dtype.size_in_bits() / 8
@@ -367,7 +387,6 @@
         res.alignment = self.alignment
         res.bandwidth_compression_scale = self.bandwidth_compression_scale
         res.storage_rounding_quantum = self.storage_rounding_quantum
-        res.address = None
 
         if self.quantization is not None:
             res.quantization = self.quantization.clone()
diff --git a/ethosu/vela/test/test_lut.py b/ethosu/vela/test/test_lut.py
index 3dda179..ee1a40f 100644
--- a/ethosu/vela/test/test_lut.py
+++ b/ethosu/vela/test/test_lut.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 # Description:
 # Unit tests for LUT support
+import random
+
 import numpy as np
 
 from ethosu.vela import insert_dma
@@ -31,29 +33,29 @@
 
 
 def set_256_lut(op, key):
-    values = list(range(256))
+    random.seed(key)
+    values = random.choices(range(256), k=256)
     lut_tensor = create_const_tensor(
         op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, np.uint8, TensorPurpose.LUT
     )
-    lut_tensor.equivalence_id = lut.create_equivalence_id(key)
     op.set_activation_lut(lut_tensor)
 
 
 def set_1K_lut(op, key):
-    values = list(range(256))
+    random.seed(key)
+    values = random.choices(range(256), k=256)
     lut_tensor = create_const_tensor(
         op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, np.uint32, TensorPurpose.LUT
     )
-    lut_tensor.equivalence_id = lut.create_equivalence_id(key)
     op.set_activation_lut(lut_tensor)
 
 
 def set_2K_lut(op, key):
-    values = list(range(512))
+    random.seed(key)
+    values = random.choices(range(512), k=512)
     lut_tensor = create_const_tensor(
         op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, np.uint32, TensorPurpose.LUT
     )
-    lut_tensor.equivalence_id = lut.create_equivalence_id(key)
     op.set_activation_lut(lut_tensor)