blob: 1a00373575628609ca4169be85a026201b71fbde [file] [log] [blame]
Rickard Bolinbc6ee582022-11-04 08:24:29 +00001# SPDX-FileCopyrightText: Copyright 2020, 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Tim Hall79d07d22020-04-27 18:20:16 +010017# Description:
18# Helper classes to track memory accesses for calculating dependencies between Commands.
import weakref
from enum import IntEnum
from functools import lru_cache
21
22
class RangeSet:
    """A Range set class to track ranges and whether they intersect.
    Intended for e.g. tracking sets of memory ranges and whether two commands use the same memory areas.

    Ranges are half-open ``[start, end)`` intervals held as ``(start, end)`` tuples
    in ``self.ranges``. The list is kept sorted in ascending order (by start, then
    end). Ranges are never merged, so entries may overlap or repeat.
    """

    def __init__(self, start=None, end=None, ranges=None):
        """Create a range set.

        Args:
            start: Start (inclusive) of an initial range, or None for no initial range.
            end: End (exclusive) of the initial range. If ``start == end`` the range is
                empty and is not stored.
            ranges: Optional iterable of ``(start, end)`` tuples to seed the set with.
                It is copied, so the caller's list is never modified or aliased.
        """
        # Copy and sort: intersects() depends on sorted order, and later appends
        # must not leak into a list owned by the caller.
        self.ranges = [] if ranges is None else sorted(ranges)

        if start is not None and start != end:
            assert start < end
            self.ranges.append((start, end))
            # Restore sorted order in case seeded ranges start after `start`.
            self.ranges.sort()

    def __or__(self, other):
        """Return a new RangeSet holding the ranges of both operands."""
        # The constructor copies and sorts the combined list.
        return RangeSet(ranges=self.ranges + other.ranges)

    def __ior__(self, other):
        """Add the ranges of ``other`` to this set and return self.

        ``self.ranges`` is rebound to a new sorted list rather than mutated in place.
        """
        self.ranges = sorted(self.ranges + other.ranges)
        return self

    def intersects(self, other):
        """Return True if any range in this set overlaps any range in ``other``.

        Both lists are walked in one merge-style pass, relying on each being sorted
        by start. Empty ranges never intersect anything.
        """
        a_ranges = self.ranges
        b_ranges = other.ranges

        a_idx = 0
        b_idx = 0

        while a_idx < len(a_ranges) and b_idx < len(b_ranges):
            a_start, a_end = a_ranges[a_idx]
            b_start, b_end = b_ranges[b_idx]
            if max(a_start, b_start) < min(a_end, b_end):
                return True  # intersection

            # No overlap. The range that ends first cannot overlap any remaining range
            # of the other list, because those all start at or after the other list's
            # current start. So it is safe to skip it. Comparing ends rather than starts
            # also copes with empty ranges, where equal starts need not mean an overlap.
            if a_end <= b_end:
                a_idx += 1
            else:
                b_idx += 1

        return False

    def __str__(self):
        return "<RangeSet %s>" % (["%#x:%#x" % (int(start), int(end)) for start, end in self.ranges],)

    __repr__ = __str__
72
73
class MemoryRangeSet:
    """Extended version of the RangeSet class that handles having different memory areas

    ``self.regions`` maps each memory area to a RangeSet. Ranges that live in
    different memory areas are never considered to intersect.
    """

    def __init__(self, mem_area=None, start=None, end=None, regions=None):
        # NOTE: a caller-supplied ``regions`` dict is used as-is (not copied).
        self.regions = {} if regions is None else regions

        if mem_area is not None:
            self.regions[mem_area] = RangeSet(start, end)

    def _union_regions(self, other):
        """Build a new dict mapping every area of either operand to the union of its ranges."""
        merged = {}
        for area in self.regions.keys() | other.regions.keys():
            lhs = self.regions.get(area, RangeSet())
            rhs = other.regions.get(area, RangeSet())
            merged[area] = lhs | rhs
        return merged

    def __or__(self, other):
        """Return a new MemoryRangeSet containing the ranges of both operands."""
        return MemoryRangeSet(regions=self._union_regions(other))

    def __ior__(self, other):
        """Merge ``other`` into this set by rebinding ``self.regions``; return self."""
        self.regions = self._union_regions(other)
        return self

    def intersects(self, other):
        """Return True if some memory area shared by both sets has overlapping ranges."""
        shared_areas = self.regions.keys() & other.regions.keys()
        return any(self.regions[area].intersects(other.regions[area]) for area in shared_areas)

    def __str__(self):
        parts = ["%s: %s\t" % (area, ranges) for area, ranges in self.regions.items()]
        return "<MemoryRangeSet>" + "".join(parts)

    __repr__ = __str__
113
114
class AccessDirection(IntEnum):
    """Direction of a memory access; members are ints and can index per-direction lists.

    ``Size`` is not a direction itself: it is the number of real directions
    (Read and Write) and is used to size per-direction containers.
    """

    Read = 0  # the memory is read
    Write = 1  # the memory is written
    Size = 2  # count of the directions above, not an access kind
119
120
class MemoryAccessSet:
    """Tracks memory ranges, but also access patterns to know which accesses actually are in conflict

    Accesses should be recorded through add(). Mutating ``accesses`` directly is
    not tracked and can leave memoized conflicts() results stale.
    """

    def __init__(self):
        # One MemoryRangeSet per access direction, indexed by AccessDirection.
        self.accesses = [MemoryRangeSet() for _ in range(AccessDirection.Size)]
        # Bumped by every add() so that memoized conflict results can be detected as stale.
        self._version = 0
        # Per-instance memo for conflicts(): other -> ((self._version, other._version), result).
        # Weak keys ensure the memo never keeps another access set alive. A class-level
        # functools.lru_cache on the method would retain every instance for the life of
        # the process and keep returning old answers after add() changed either operand.
        self._conflict_cache = weakref.WeakKeyDictionary()

    def add(self, memory_range_set, access):
        """Record an access to ``memory_range_set`` in direction ``access``.

        Args:
            memory_range_set: MemoryRangeSet describing the accessed memory.
            access: AccessDirection.Read or AccessDirection.Write; used as an index
                into ``self.accesses``.
        """
        self.accesses[access] |= memory_range_set
        self._version += 1

    def conflicts(self, other):
        """Return True if the accesses recorded in ``self`` and ``other`` conflict.

        Overlapping write -> read, read -> write and write -> write accesses are
        conflicts; overlapping read -> read accesses are not. Results are memoized
        per ``other`` and recomputed if either set has been changed by add() since.
        """
        versions = (self._version, other._version)
        cached = self._conflict_cache.get(other)
        if cached is not None and cached[0] == versions:
            return cached[1]
        result = self._compute_conflicts(other)
        self._conflict_cache[other] = (versions, result)
        return result

    def _compute_conflicts(self, other):
        """Uncached conflict check behind conflicts()."""
        own_reads = self.accesses[AccessDirection.Read]
        own_writes = self.accesses[AccessDirection.Write]
        other_reads = other.accesses[AccessDirection.Read]
        other_writes = other.accesses[AccessDirection.Write]

        # True dependencies, or write -> read
        if own_writes.intersects(other_reads):
            return True

        # Anti-dependencies, or read -> write
        if own_reads.intersects(other_writes):
            return True

        # Output dependencies, or write -> write
        if own_writes.intersects(other_writes):
            return True

        # read -> read does not cause a conflict
        return False

    def __str__(self):
        return "Read: %s\nWrite: %s\n\n" % (self.accesses[AccessDirection.Read], self.accesses[AccessDirection.Write])

    __repr__ = __str__
151 __repr__ = __str__