blob: 10a5c835c24c05885c3a92c5b5431a87a11353c0 [file] [log] [blame]
Rickard Bolinbc6ee582022-11-04 08:24:29 +00001# SPDX-FileCopyrightText: Copyright 2020, 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Tim Hall79d07d22020-04-27 18:20:16 +010017# Description:
18# Helper classes to track memory accesses for calculating dependencies between Commands.
Tim Hall79d07d22020-04-27 18:20:16 +010019from enum import IntEnum
Tim Hall79d07d22020-04-27 18:20:16 +010020from functools import lru_cache
21
22
class RangeSet:
    """A set of half-open ``[start, end)`` ranges that can be tested for intersection.

    Intended for e.g. tracking sets of memory ranges and whether two commands use the same memory areas.

    ``self.ranges`` is a list of ``(start, end)`` tuples kept in ascending order of ``start``.
    Empty or inverted ranges (``start >= end``) are dropped on construction: under the half-open
    overlap test they can never intersect anything. Overlapping ranges are not coalesced.
    """

    def __init__(self, start=None, end=None, ranges=None):
        """Create a range set.

        Args:
            start: start (inclusive) of an optional single range to add. Ignored when None.
            end: end (exclusive) of that range; must be comparable with ``start``.
            ranges: optional iterable of ``(start, end)`` tuples. It is copied, never mutated.
        """
        # Copy so the caller's list is never aliased or mutated by the append below.
        candidates = list(ranges) if ranges is not None else []
        if start is not None:
            candidates.append((start, end))
        # Keep only non-empty ranges, sorted by start as intersects() requires.
        self.ranges = sorted((s, e) for s, e in candidates if s < e)

    def __or__(self, other):
        """Return a new set holding the ranges of both operands; the operands are unchanged."""
        return RangeSet(ranges=self.ranges + other.ranges)

    def __ior__(self, other):
        """Add the ranges of ``other`` to this set in place and return ``self``."""
        self.ranges = sorted(self.ranges + other.ranges)
        return self

    def intersects(self, other):
        """Return True if any range in this set overlaps any range in ``other``.

        Both range lists are sorted by start, so they are walked together like a merge, which takes
        O(len(self.ranges) + len(other.ranges)) steps.
        """
        a_ranges = self.ranges
        b_ranges = other.ranges

        a_idx = 0
        b_idx = 0

        while a_idx < len(a_ranges) and b_idx < len(b_ranges):
            a_start, a_end = a_ranges[a_idx]
            b_start, b_end = b_ranges[b_idx]
            if max(a_start, b_start) < min(a_end, b_end):
                return True  # the half-open ranges overlap

            # No overlap: the range that ends first cannot overlap anything later in the other
            # list (those all start at or after the current one), so advance past it. This choice
            # stays correct even for empty or inverted ranges, so no assertion is needed.
            if a_end <= b_end:
                a_idx += 1
            else:
                b_idx += 1

        return False

    def __str__(self):
        return "<RangeSet %s>" % (["%#x:%#x" % (int(start), int(end)) for start, end in self.ranges],)

    __repr__ = __str__
71
72
class MemoryRangeSet:
    """Extended version of the RangeSet class that handles having different memory areas.

    ``self.regions`` maps each memory area to the RangeSet of ranges used in that area. Ranges in
    different memory areas are never considered to intersect.
    """

    def __init__(self, mem_area=None, start=None, end=None, regions=None):
        """Create a set, optionally holding the single range ``[start, end)`` in ``mem_area``.

        Args:
            mem_area: memory area (used as a dict key) for the optional initial range.
            start: start of the initial range, passed to RangeSet.
            end: end of the initial range, passed to RangeSet.
            regions: optional mapping of memory area -> RangeSet. The mapping is copied, so the
                caller's dict is not modified when ``mem_area`` is also given.
        """
        self.regions = dict(regions) if regions is not None else {}

        if mem_area is not None:
            self.regions[mem_area] = RangeSet(start, end)

    def _merged_regions(self, other):
        """Return a new area -> RangeSet dict with the per-area union of ``self`` and ``other``."""
        return {
            mem_area: (self.regions.get(mem_area, RangeSet()) | other.regions.get(mem_area, RangeSet()))
            for mem_area in (self.regions.keys() | other.regions.keys())
        }

    def __or__(self, other):
        """Return a new MemoryRangeSet holding the union of both operands; operands are unchanged."""
        return MemoryRangeSet(regions=self._merged_regions(other))

    def __ior__(self, other):
        """Union ``other`` into this set in place and return ``self``."""
        self.regions = self._merged_regions(other)
        return self

    def intersects(self, other):
        """Return True if the two sets have overlapping ranges within some shared memory area."""
        return any(
            self.regions[mem_area].intersects(other.regions[mem_area])
            for mem_area in self.regions.keys() & other.regions.keys()
        )

    def __str__(self):
        s = "<MemoryRangeSet>"
        for mem_area, rng in self.regions.items():
            s += "%s: %s\t" % (mem_area, rng)
        return s

    __repr__ = __str__
112
113
class AccessDirection(IntEnum):
    """Kind of memory access; members double as indices into a per-direction list.

    ``Size`` is not itself a direction: its value is the count of the real directions
    (``Read`` and ``Write``) and is used to size such lists.
    """

    Read, Write, Size = range(3)
118
119
class MemoryAccessSet:
    """Tracks memory ranges, but also access patterns to know which accesses actually are in conflict.

    ``self.accesses`` holds one MemoryRangeSet per access direction, indexed by AccessDirection.
    """

    def __init__(self):
        self.accesses = [MemoryRangeSet() for _ in range(AccessDirection.Size)]
        # Bumped on every add() so cached conflict results involving this set can be recognised as stale.
        self._version = 0
        # Per-instance memo for conflicts(): other -> (other._version, result). It is cleared whenever
        # this set changes and lives only as long as this instance. A functools.lru_cache on the method
        # would instead keep every instance alive for the whole process and go stale after add().
        # Note: entries keep ``other`` alive for as long as ``self`` lives.
        self._conflict_cache = {}

    def add(self, memory_range_set, access):
        """Record that the ranges in ``memory_range_set`` are accessed in direction ``access``.

        Args:
            memory_range_set: MemoryRangeSet of ranges being accessed.
            access: AccessDirection (Read or Write) used to index ``self.accesses``.
        """
        self.accesses[access] |= memory_range_set
        self._version += 1
        self._conflict_cache.clear()

    def conflicts(self, other):
        """Return True if the accesses of ``self`` and ``other`` conflict.

        A conflict is a write in one set overlapping a read or a write in the other. Overlapping reads
        alone are not a conflict. Results are memoised per ``other`` until either set is modified via add().
        """
        cached = self._conflict_cache.get(other)
        if cached is not None and cached[0] == other._version:
            return cached[1]
        result = self._compute_conflicts(other)
        self._conflict_cache[other] = (other._version, result)
        return result

    def _compute_conflicts(self, other):
        """Uncached conflict check used by conflicts()."""
        # True dependencies, or write -> read
        if self.accesses[AccessDirection.Write].intersects(other.accesses[AccessDirection.Read]):
            return True

        # Anti-dependencies, or read -> write
        if self.accesses[AccessDirection.Read].intersects(other.accesses[AccessDirection.Write]):
            return True

        # Output dependencies, or write -> write
        if self.accesses[AccessDirection.Write].intersects(other.accesses[AccessDirection.Write]):
            return True

        # read -> read does not cause a conflict
        return False

    def __str__(self):
        return "Read: %s\nWrite: %s\n\n" % (self.accesses[AccessDirection.Read], self.accesses[AccessDirection.Write])

    __repr__ = __str__