MLBEDSW-3808: Ported search allocator to python

- Straight port of the C++ implementation to python.
- Renamed the allocator from "Search" to "HillClimb"

Change-Id: I50797d541f326d0264daf79bf7866aef32350a60
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/hillclimb_allocation.py b/ethosu/vela/hillclimb_allocation.py
new file mode 100644
index 0000000..de53ab8
--- /dev/null
+++ b/ethosu/vela/hillclimb_allocation.py
@@ -0,0 +1,310 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Tensor allocator based on a hill-climb search
+import random
+from typing import List
+from typing import Set
+
+from .live_range import LiveRange
+
+
+class LiveRangeInfo:
+    def __init__(self, id: int, start_time: int, end_time: int, size: int):
+        # Index of this live range
+        self.id = id
+        # Start time (input to the allocator algorithm)
+        self.start_time = start_time
+        # End time, inclusive (input to the allocator algorithm)
+        self.end_time = end_time
+        # Size in bytes (input to the allocator algorithm)
+        self.size = size
+        # Allocated address (the main output from the allocator algorithm)
+        self.address: int = 0
+        # End address, exclusive
+        self.end_address: int = 0
+        # id of predecessor live range (predecessor's end address == this lr's address)
+        self.predecessor: int = 0
+        # Turn at which the live range was allocated
+        self.turn: int = 0
+        # Max value of size_at_time (only used in the heuristic allocation)
+        self.urgency = 0
+        self.neighbours: List["LiveRangeInfo"] = []
+
+    def overlaps(self, addr2: int, size2: int) -> int:
+        return self.address < addr2 + size2 and addr2 < self.end_address
+
+    def is_neighbour(self, lr: "LiveRangeInfo") -> bool:
+        return self.start_time <= lr.end_time and lr.start_time <= self.end_time
+
+    def __str__(self):
+        return "<LiveRangeInfo: id={}, start_time={}, end_time={}, size={}, address={}>".format(
+            self.id, self.start_time, self.end_time, self.size, self.address
+        )
+
+    def __lt__(self, other) -> bool:
+        if self.urgency != other.urgency:
+            return self.urgency > other.urgency
+        duration1 = self.end_time - self.start_time
+        duration2 = other.end_time - other.start_time
+        if duration1 != duration2:
+            return duration1 > duration2
+
+        if self.start_time != other.start_time:
+            return self.start_time < other.start_time
+
+        if self.size != other.size:
+            return self.size > other.size
+
+        return self.id < other.id
+
+
+class HillClimbAllocator:
+    """
+    Implements tensor allocator using a hill climb search.
+
+    The basic algorithm is:
+
+    Use a heuristic allocator to find an initial allocation
+    while allocation is not optimal and iterations < MAX_ITERATIONS:
+        find the "bottleneck": the live range with highest end address
+        find all live ranges that affected the allocation of the bottleneck
+        swap the order of any two affecting live ranges
+        reallocate tensors using the reordered live ranges
+        if the new allocation is better: keep it, else set allocation to previous allocation
+    """
+
+    MAX_ITERATIONS = 500
+    NOT_ALLOCATED = -1
+    # Used for live ranges allocated at address 0
+    NO_PREDECESSOR = -1
+
+    def __init__(self, live_ranges: List[LiveRange]):
+        # Contains the live ranges
+        self.lrs: List[LiveRangeInfo] = [
+            LiveRangeInfo(id, lr.start_time, lr.end_time, lr.size) for id, lr in enumerate(live_ranges)
+        ]
+        self.lrs_at_time = []
+        # The available size (input to algorithm).
+        self.available_size: int = 0
+        # The algorithm stops once the target size has been achieved
+        self.target_size: int = 0
+        # The highest end address of the best found allocation
+        self.best_size: int = 1 << 63
+        # For each live range: max value of size_at_time (only used in the heuristic allocation)
+        self.lr_urgency = len(self.lrs) * [0]
+        nr_time_slots = 1 + max(lr.end_time for lr in self.lrs)
+        # Contains active live ranges at each timestamp
+        self.lrs_at_time = [[] for i in range(nr_time_slots)]
+        for lr in self.lrs:
+            for t in range(lr.start_time, lr.end_time + 1):
+                self.lrs_at_time[t].append(lr)
+        # At each timestamp: accumulated size of active live ranges
+        size_at_time = [sum(lr.size for lr in self.lrs_at_time[t]) for t in range(nr_time_slots)]
+        # The minimum possible size, assuming all live ranges can be perfectly allocated
+        self.min_required_size: int = max(size_at_time)
+        # Calculate all neighbours + the urgency of each live range
+        for lr in self.lrs:
+            lr.urgency = 0
+            lr.neighbours = []
+            neighbours = set()
+            for t in range(lr.start_time, lr.end_time + 1):
+                lr.urgency = max(size_at_time[t], lr.urgency)
+                for lr2 in self.lrs_at_time[t]:
+                    if lr2 not in neighbours and lr != lr2:
+                        neighbours.add(lr2)
+                        lr.neighbours.append(lr2)
+
+    def allocate_lr(self, lr: LiveRangeInfo):
+        """
+        Allocates the given live range at the smallest possible address
+        """
+        address = 0
+        predecessor = HillClimbAllocator.NO_PREDECESSOR
+        fits = False
+        while not fits:
+            fits = True
+            # Find neighbours that overlap with address
+            for lr2 in lr.neighbours:
+                if lr2.address == HillClimbAllocator.NOT_ALLOCATED or lr2.end_address <= address:
+                    continue
+                if lr2.overlaps(address, lr.size):
+                    # Overlap found increase address
+                    fits = False
+                    address = lr2.end_address
+                    predecessor = lr2.id
+        lr.address = address
+        lr.end_address = address + lr.size
+        lr.predecessor = predecessor
+
+    def allocate_indices(self, indices: List[int]):
+        """
+        Allocates the live ranges in the order indicated by the indices;
+        allocates each live range at the lowest possible address.
+        """
+        for lr in self.lrs:
+            lr.address = HillClimbAllocator.NOT_ALLOCATED
+        size = 0
+        for turn, index in enumerate(indices):
+            lr = self.lrs[index]
+            self.allocate_lr(lr)
+            lr.turn = turn
+            size = max(size, lr.end_address)
+            if size > self.best_size:
+                # This allocation is worse than the best known allocation;
+                # no need to continue
+                break
+        return size
+
+    def add_predecessor_turns(self, turn_set: Set[int], turn_list: List[int], lr: LiveRangeInfo):
+        """
+        Adds the given live range + predecessors to the turns vector.
+        Note: the turn_set is for quick detection of duplicates,
+        the turn_list is to get reproduceable results
+        """
+        if lr.turn not in turn_set:
+            turn_set.add(lr.turn)
+            turn_list.append(lr.turn)
+        id = lr.id
+        while self.lrs[id].predecessor != HillClimbAllocator.NO_PREDECESSOR:
+            id = self.lrs[id].predecessor
+            turn = self.lrs[id].turn
+            if turn not in turn_set:
+                turn_set.add(turn)
+                turn_list.append(turn)
+
+    def attempt_bottleneck_fix(self, indices: List[int]):
+        """
+        Finds the "bottleneck", the live range with highest end address, and reorders the indices
+        such that a next allocation might lower the memory usage.
+
+                                  ---------
+                                  |       |
+                                  |   D   |
+                                  |       |
+         ----------------------------------
+         |           B                 |
+         -------------------------------
+         | |
+         |A|                      ---
+         | |                      |C|
+         | |                      | |
+         ---------------------------------------
+
+        In the above example, the allocation order was [A, B, C, D] and D is the resulting bottle-neck.
+        The live ranges that affected the allocation of D are the direct neighbours of D (i.e. B and C),
+        and all direct and indirect predecessors of D and its neighbours
+        (i.e. A, which is the predecessor of B, and indirect predecessor of D).
+
+        By permuting the order in which the affecting live ranges are allocated, the bottleneck might
+        be lowered. In the above example, almost any permutation would lower the bottleneck.
+        """
+        # Find the bottleneck
+        max_lr = self.lrs[0]
+        for lr in self.lrs[1:]:
+            if lr.end_address > max_lr.end_address:
+                max_lr = lr
+
+        # Find all live ranges that affected the placement of the bottleneck live range.
+        # This consists of two types of live ranges:
+        # - direct neighbours of the bottleneck live range
+        # - direct and indirect predecessors of these neighbours + bottleneck
+        # The turns at which these live ranges were allocated are put in the turns set.
+        turn_set = set()
+        turn_list = list()
+        self.add_predecessor_turns(turn_set, turn_list, max_lr)
+        for lr2 in max_lr.neighbours:
+            self.add_predecessor_turns(turn_set, turn_list, lr2)
+
+        # Non-direct neighbours that interfere with the allocation of the bottleneck are the
+        # immediate cause for gaps in the allocation, and are selected with higher probability.
+        non_nb_turn_list = []
+        for turn in turn_list:
+            lr = self.lrs[indices[turn]]
+            if not max_lr.is_neighbour(lr):
+                non_nb_turn_list.append(turn)
+        assert turn_list
+        # Pick from non-neighbour list with 30% probability
+        # (magic number based on tuning)
+        if random.randint(0, 100) < 30 and non_nb_turn_list:
+            # Pick a live range from the "non-neighbour list"
+            ix1 = non_nb_turn_list[random.randint(0, len(non_nb_turn_list) - 1)]
+        else:
+            # Pick any affecting live range.
+            ix1 = turn_list[random.randint(0, len(turn_list) - 1)]
+
+        ix2 = turn_list[random.randint(0, len(turn_list) - 2)]
+        if ix1 == ix2:
+            ix2 = turn_list[-1]
+        # Swap indices
+        indices[ix1], indices[ix2] = indices[ix2], indices[ix1]
+
+    def search(self, indices: List[int], iterations: int):
+        """
+        Search for a solution, using the given indices as initial solution.
+        """
+        best_indices = indices[:]
+        for _ in range(iterations):
+            # Reorder the indices
+            self.attempt_bottleneck_fix(indices)
+            # Allocate the reordered indices and check if it gave an improvement
+            new_size = self.allocate_indices(indices)
+            if new_size <= self.best_size:
+                # The new allocation produced a new best result remember it
+                self.best_size = new_size
+                best_indices = indices[:]
+                self.allocated_addresses = [lr.address for lr in self.lrs]
+                if self.best_size <= self.min_required_size:
+                    # Target reached stop
+                    return
+            else:
+                # The new allocation produced worse result undo the change
+                indices = best_indices[:]
+
+    def allocate(self) -> List[int]:
+        """
+        Runs the allocation algorithm. Finishes when an optimal solution has been
+        found or when maximum iterations have been run.
+        The allocated addresses are placed in the output vector, in the same
+        order as the input vector.
+
+        Implementation note: the algorithm produces reproduceable results by using
+        a well-defined random number generator with well-defined default seed,
+        and using a fixed number of iterations.
+        """
+        random.seed(1)
+        # Sort indices on priority. Note: self.lrs must be left unchanged
+        indices = [lr.id for lr in sorted(self.lrs)]
+        # Allocate the initial solution
+        self.best_size = self.allocate_indices(indices)
+        self.allocated_addresses = [lr.address for lr in self.lrs]
+        if self.best_size > self.min_required_size:
+            # Try to improve the heuristic allocation
+            self.search(indices, HillClimbAllocator.MAX_ITERATIONS)
+        # else the heuristic allocation returned an optimal solution; no search needed
+        return self.allocated_addresses
+
+
+def allocate_live_ranges(lrs: List[LiveRange]) -> List[int]:
+    """
+    Allocates live ranges using a search based allocator.
+    Returns the list of allocated addresses (one for each live range)
+    """
+    if not lrs:
+        return []
+    allocator = HillClimbAllocator(lrs)
+    return allocator.allocate()
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index db878bc..c45d0e3 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -38,7 +38,7 @@
 class TensorAllocator(enum.Enum):
     LinearAlloc = 1
     Greedy = 2
-    Search = 3
+    HillClimb = 3
 
     def __str__(self):
         return self.name
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index 0a7da5d..b7057f0 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -20,6 +20,7 @@
 
 import numpy as np
 
+from . import hillclimb_allocation
 from . import live_range
 from . import numeric_util
 from .errors import AllocationError
@@ -30,7 +31,6 @@
 from .tensor import MemType
 from .tensor import Tensor
 from .tensor import TensorPurpose
-from ethosu import tensor_allocator
 
 
 def linear_allocate_live_ranges(live_ranges, alloc_granularity=Tensor.AllocationQuantum):
@@ -63,22 +63,14 @@
     return total_sz
 
 
-def search_allocate_live_ranges(live_ranges: LiveRangeGraph, alloc_granularity: int) -> int:
-    # Allocates using the search-based allocator (implemented in C++)
-    input = []
-    lrs = []
-    lr_set = set()
-    for lr in live_ranges.ranges.values():
-        lr_set.add((lr.start_time, lr.end_time, lr))
-    lr_list = sorted(lr_set)
-    # Create a single array of ints containing start/end/size of the live ranges
-    for start, end, lr in lr_list:
-        input += [start, end, numeric_util.round_up(lr.size, alloc_granularity)]
-        lrs.append(lr)
-    addresses = tensor_allocator.allocate(input, 0)
+def hillclimb_allocate_live_ranges(live_ranges: LiveRangeGraph, alloc_granularity: int) -> int:
+    # Allocates using the hill climb allocator
+    lr_set = {(lr.start_time, lr.end_time, lr) for lr in live_ranges.ranges.values()}
+    lr_list = [lr for _, _, lr in lr_set]
+    addresses = hillclimb_allocation.allocate_live_ranges(lr_list)
     # The result is a list containing the allocated addresses
     total_sz = 0
-    for lr, address in zip(lrs, addresses):
+    for lr, address in zip(lr_list, addresses):
         total_sz = max(total_sz, address + lr.size)
         lr.set_address(address)
     verify_allocation(live_ranges, alloc_granularity)
@@ -179,8 +171,8 @@
             verify_allocation(lrs, cpu_tensor_alignment)
         elif tens_alloc == TensorAllocator.LinearAlloc:
             total_sz = linear_allocate_live_ranges(lrs, cpu_tensor_alignment)
-        elif tens_alloc == TensorAllocator.Search:
-            total_sz = search_allocate_live_ranges(lrs, cpu_tensor_alignment)
+        elif tens_alloc == TensorAllocator.HillClimb:
+            total_sz = hillclimb_allocate_live_ranges(lrs, cpu_tensor_alignment)
         else:
             assert 0
         alloc_ok = max_size is None or total_sz <= max_size
diff --git a/ethosu/vela/test/test_hillclimb_allocation.py b/ethosu/vela/test/test_hillclimb_allocation.py
new file mode 100644
index 0000000..8a56c3f
--- /dev/null
+++ b/ethosu/vela/test/test_hillclimb_allocation.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Unit tests for hillclimb_allocator.
+import pytest
+
+from ethosu.vela.hillclimb_allocation import allocate_live_ranges
+from ethosu.vela.live_range import LiveRange
+
+
+test_data = [
+    ([(0, 100, 8000), (0, 1, 8016), (100, 110, 2000), (108, 110, 4000), (109, 110, 6000)], 16016),
+    (
+        [
+            (0, 23, 131072),
+            (4, 5, 65568),
+            (4, 9, 8192),
+            (8, 30, 15360),
+            (10, 11, 65568),
+            (10, 15, 4096),
+            (16, 17, 65552),
+            (16, 21, 2048),
+            (22, 23, 32784),
+            (22, 27, 1024),
+        ],
+        216096,
+    ),
+]
+
+
+def live_range(start_time, end_time, size):
+    lr = LiveRange(None, 1)
+    lr.start_time = start_time
+    lr.end_time = end_time
+    lr.size = size
+    return lr
+
+
+@pytest.mark.parametrize("lrs, expected_size", test_data)
+def test_allocate(lrs, expected_size):
+    """Tests the search allocator"""
+    lr_list = [live_range(start, end, size) for start, end, size in lrs]
+    res = allocate_live_ranges(lr_list)
+    assert len(res) == len(lrs)
+    assert max(addr + lr[2] for addr, lr in zip(res, lrs)) == expected_size
+
+
+def test_allocate_empty_input():
+    assert [] == allocate_live_ranges([])
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index c4510b1..7271002 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -317,7 +317,7 @@
     )
     parser.add_argument(
         "--tensor-allocator",
-        default=TensorAllocator.Search,
+        default=TensorAllocator.HillClimb,
         type=lambda s: TensorAllocator[s],
         choices=list(TensorAllocator),
         help="Tensor Allocator algorithm (default: %(default)s)",