blob: 64de9709c5319e6e7a3b6ac3730fbcccd676414f [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18# Description:
19# Helper classes to track memory accesses for calculating dependencies between Commands.
20
21from enum import IntEnum
22from collections import defaultdict
23from functools import lru_cache
24
25
class RangeSet:
    """A Range set class to track ranges and whether they intersect.

    Intended for e.g. tracking sets of memory ranges and whether two commands use the same memory areas.

    Each range is a ``(start, end)`` tuple where ``end`` is exclusive: ``(0, 10)`` and ``(10, 20)``
    do not intersect. Ranges inside one set are kept sorted by start but are not merged, so they
    may overlap each other.
    """

    def __init__(self, start=None, end=None, ranges=None):
        """Create a set from an optional list of ranges plus an optional single (start, end) range.

        ``ranges`` is copied and sorted, so the caller's list is never modified and the ordering
        invariant holds even if the input was unsorted. A ``start``/``end`` pair with
        ``start == end`` describes an empty range and is ignored.
        """
        # track a list of (start, end) tuples, always in ascending order sorted by start.
        self.ranges = sorted(ranges) if ranges is not None else []

        if start is not None and start != end:
            assert start < end
            self.ranges.append((start, end))
            # the new range may belong in front of already present ones
            self.ranges.sort()

    def __or__(self, other):
        """Return a new RangeSet holding the ranges of both operands."""
        return RangeSet(ranges=sorted(self.ranges + other.ranges))

    def __ior__(self, other):
        """Add the ranges of ``other`` to this set in place."""
        self.ranges = sorted(self.ranges + other.ranges)
        return self

    def intersects(self, other):
        """Return True if any range of this set overlaps any range of ``other``."""
        a_ranges = self.ranges
        b_ranges = other.ranges

        a_idx = 0
        b_idx = 0

        while a_idx < len(a_ranges) and b_idx < len(b_ranges):
            ar = a_ranges[a_idx]
            br = b_ranges[b_idx]

            # An empty range cannot overlap anything. Skipping it up front also guarantees that
            # the "advance the lower start" step below only ever compares non-empty ranges.
            if ar[0] >= ar[1]:
                a_idx += 1
                continue
            if br[0] >= br[1]:
                b_idx += 1
                continue

            if max(ar[0], br[0]) < min(ar[1], br[1]):
                return True  # intersection

            # Both ranges are non-empty and disjoint, so the one starting first also ends before
            # the other one starts; it cannot overlap any later range of the other set either.
            # Equal starts cannot reach this point, they would have been an intersection.
            if ar[0] < br[0]:
                a_idx += 1
            else:
                b_idx += 1

        return False

    def __str__(self):
        return "<RangeSet %s>" % (["%#x:%#x" % (int(start), int(end)) for start, end in self.ranges],)

    __repr__ = __str__
75
76
class MemoryRangeSet:
    """Extended version of the RangeSet class that handles having different memory areas.

    ``regions`` maps a memory area key to the RangeSet tracked for that area.
    """

    def __init__(self, mem_area=None, start=None, end=None, regions=None):
        self.regions = {} if regions is None else regions

        # optionally seed the set with a single range in one memory area
        if mem_area is not None:
            self.regions[mem_area] = RangeSet(start, end)

    def _merged_regions(self, other):
        # Build the per-area union of both operands; an area missing on one side
        # contributes an empty RangeSet.
        merged = {}
        for area in self.regions.keys() | other.regions.keys():
            mine = self.regions.get(area, RangeSet())
            theirs = other.regions.get(area, RangeSet())
            merged[area] = mine | theirs
        return merged

    def __or__(self, other):
        return MemoryRangeSet(regions=self._merged_regions(other))

    def __ior__(self, other):
        self.regions = self._merged_regions(other)
        return self

    def intersects(self, other):
        # only memory areas present in both sets can overlap
        common_areas = self.regions.keys() & other.regions.keys()
        return any(self.regions[area].intersects(other.regions[area]) for area in common_areas)

    def __str__(self):
        entries = "".join("%s: %s\t" % (area, rng) for area, rng in self.regions.items())
        return "<MemoryRangeSet>" + entries

    __repr__ = __str__
116
117
class AccessDirection(IntEnum):
    """Kind of memory access; the integer values are used as list indices by MemoryAccessSet."""

    Read = 0
    Write = 1
    # Not an access kind of its own: the number of directions above, used to size per-direction lists.
    Size = 2
122
123
class MemoryAccessSet:
    """Tracks memory ranges, but also access patterns to know which accesses actually are in conflict"""

    def __init__(self):
        # one MemoryRangeSet per access direction, indexed by AccessDirection.Read / AccessDirection.Write
        self.accesses = [MemoryRangeSet() for _ in range(AccessDirection.Size)]

    def add(self, memory_range_set, access):
        """Record ``memory_range_set`` as being accessed in direction ``access`` (an AccessDirection)."""
        self.accesses[access] |= memory_range_set

    def conflicts(self, other):
        """Return True if accesses of this set and of ``other`` touch the same memory and at least one is a write.

        The result is intentionally not memoised. The instances are mutable through add() and are
        hashed by identity, so a cache keyed on (self, other) would keep returning the answer
        computed before a later add(), and would keep every instance alive for the cache's lifetime.
        """
        own_reads = self.accesses[AccessDirection.Read]
        own_writes = self.accesses[AccessDirection.Write]
        other_reads = other.accesses[AccessDirection.Read]
        other_writes = other.accesses[AccessDirection.Write]

        # True dependencies, or write -> read
        if own_writes.intersects(other_reads):
            return True

        # Anti-dependencies, or read -> write
        if own_reads.intersects(other_writes):
            return True

        # Output dependencies, or write -> write
        if own_writes.intersects(other_writes):
            return True

        # read -> read does not cause a conflict
        return False

    def __str__(self):
        return "Read: %s\nWrite: %s\n\n" % (self.accesses[AccessDirection.Read], self.accesses[AccessDirection.Write])

    __repr__ = __str__