Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 1 | # Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. |
| 2 | # |
| 3 | # SPDX-License-Identifier: Apache-2.0 |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the License); you may |
| 6 | # not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 16 | # Description: |
| 17 | # Helper classes to track memory accesses for calculating dependencies between Commands. |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 18 | from enum import IntEnum |
Tim Hall | 79d07d2 | 2020-04-27 18:20:16 +0100 | [diff] [blame] | 19 | from functools import lru_cache |
| 20 | |
| 21 | |
class RangeSet:
    """A Range set class to track ranges and whether they intersect.

    Intended for e.g. tracking sets of memory ranges and whether two commands use the
    same memory areas. ``self.ranges`` is a list of ``(start, end)`` tuples kept in
    ascending order (sorted by start, then end). Ranges are half-open: ``[0, 4)`` and
    ``[4, 8)`` do not intersect. Overlapping ranges inside one set are not merged.
    """

    def __init__(self, start=None, end=None, ranges=None):
        """Create a range set.

        :param start: start of an optional initial range (inclusive).
        :param end: end of the optional initial range (exclusive). ``start == end``
            denotes an empty range, which is not stored.
        :param ranges: optional iterable of ``(start, end)`` tuples to seed the set.
            It is copied, so the caller's list is never modified.
        """
        # Copy (dropping empty ranges) so the caller's list is never mutated, and so
        # intersects() never sees a zero-length range.
        self.ranges = [] if ranges is None else [rng for rng in ranges if rng[0] != rng[1]]

        if start is not None and start != end:
            assert start < end
            self.ranges.append((start, end))

        # intersects() relies on ascending order; the appended range or a caller-supplied
        # list may be out of order. Sorting is cheap when the input is already sorted.
        self.ranges.sort()

    def __or__(self, other):
        """Return a new RangeSet containing the ranges of both operands."""
        # The constructor copies and sorts the combined list.
        return RangeSet(ranges=self.ranges + other.ranges)

    def __ior__(self, other):
        """Add the ranges of *other* to this set in place and return this set."""
        self.ranges = sorted(self.ranges + other.ranges)
        return self

    def intersects(self, other):
        """Return True if any range of this set overlaps any range of *other*.

        Two-pointer sweep over both sorted range lists.
        """
        a_ranges = self.ranges
        b_ranges = other.ranges

        a_idx = 0
        b_idx = 0

        while a_idx < len(a_ranges) and b_idx < len(b_ranges):
            a_start, a_end = a_ranges[a_idx]
            b_start, b_end = b_ranges[b_idx]
            if max(a_start, b_start) < min(a_end, b_end):
                return True  # intersection

            # No overlap here, so the range that starts first also ends before the other
            # begins. Both lists are sorted by start, so it cannot overlap any later range
            # of the other set and can be skipped.
            if a_start < b_start:
                a_idx += 1
            else:
                # Equal starts of two non-empty ranges would have intersected above.
                assert a_start != b_start
                b_idx += 1

        return False

    def __str__(self):
        return "<RangeSet %s>" % (["%#x:%#x" % (int(start), int(end)) for start, end in self.ranges],)

    __repr__ = __str__
| 71 | |
| 72 | |
class MemoryRangeSet:
    """Extended version of the RangeSet class that handles having different memory areas.

    ``self.regions`` maps each memory area to a RangeSet. Ranges in different memory
    areas never intersect each other.
    """

    def __init__(self, mem_area=None, start=None, end=None, regions=None):
        """Create a memory range set.

        :param mem_area: optional memory area for an initial range ``[start, end)``.
        :param start: start of the initial range (used only when *mem_area* is given).
        :param end: end of the initial range (used only when *mem_area* is given).
        :param regions: optional mapping of memory area -> RangeSet to seed the set.
            The mapping is shallow-copied, so the caller's dict is never modified.
        """
        # Copy so that adding mem_area below does not mutate the caller's dict.
        self.regions = {} if regions is None else dict(regions)

        if mem_area is not None:
            self.regions[mem_area] = RangeSet(start, end)

    def _combined_regions(self, other):
        """Return a new dict mapping every memory area of either operand to the union of its ranges."""
        return {
            mem_area: (self.regions.get(mem_area, RangeSet()) | other.regions.get(mem_area, RangeSet()))
            for mem_area in (self.regions.keys() | other.regions.keys())
        }

    def __or__(self, other):
        """Return a new MemoryRangeSet containing the ranges of both operands."""
        return MemoryRangeSet(regions=self._combined_regions(other))

    def __ior__(self, other):
        """Add the ranges of *other* to this set in place and return this set."""
        self.regions = self._combined_regions(other)
        return self

    def intersects(self, other):
        """Return True if both sets have overlapping ranges within a common memory area."""
        # Only memory areas present in both sets can possibly overlap.
        for mem_area in self.regions.keys() & other.regions.keys():
            if self.regions[mem_area].intersects(other.regions[mem_area]):
                return True
        return False

    def __str__(self):
        s = "<MemoryRangeSet>"
        for mem_area, rng in self.regions.items():
            s += "%s: %s\t" % (mem_area, rng)
        return s

    __repr__ = __str__
| 112 | |
| 113 | |
class AccessDirection(IntEnum):
    """Direction of a memory access.

    Members are ints, so they can be used directly as list indices. ``Size`` is not a
    direction itself; it is the count of the real directions that precede it.
    """

    # Values 0, 1, 2 in declaration order: Read, Write, then the Size sentinel.
    Read, Write, Size = range(3)
| 118 | |
| 119 | |
class MemoryAccessSet:
    """Tracks memory ranges, but also access patterns to know which accesses actually are in conflict"""

    def __init__(self):
        # One MemoryRangeSet per access direction, indexed by AccessDirection.Read / .Write.
        self.accesses = [MemoryRangeSet() for _ in range(AccessDirection.Size)]

    def add(self, memory_range_set, access):
        """Record that *memory_range_set* is accessed in direction *access* (an AccessDirection)."""
        self.accesses[access] |= memory_range_set

    def conflicts(self, other):
        """Return True if this access set and *other* cannot be reordered with respect to each other.

        A conflict is any overlap involving at least one write. Overlapping reads alone
        do not conflict.

        Not memoized on purpose: both operands are mutable through add(), so a
        per-(self, other) cache such as functools.lru_cache would return stale results.
        It would also keep every instance alive for the lifetime of the process.
        """
        own_reads = self.accesses[AccessDirection.Read]
        own_writes = self.accesses[AccessDirection.Write]
        other_reads = other.accesses[AccessDirection.Read]
        other_writes = other.accesses[AccessDirection.Write]

        # Checked in order and short-circuited; read -> read never conflicts.
        return (
            own_writes.intersects(other_reads)  # true dependency: write -> read
            or own_reads.intersects(other_writes)  # anti-dependency: read -> write
            or own_writes.intersects(other_writes)  # output dependency: write -> write
        )

    def __str__(self):
        return "Read: %s\nWrite: %s\n\n" % (self.accesses[AccessDirection.Read], self.accesses[AccessDirection.Write])

    __repr__ = __str__