blob: e6e24e6291acbd3e23355c405c4346514e716da6 [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18# Description:
19# Functions for abstracting out the traversal and rewriting of graphs so that the optimisation passes can focus on the
20# correct operation.
21#
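# Each traversal takes two lists of callback functions: one applied to Tensors and one applied to Operations. Every
# callback is called as fn(obj, arch).
#
# Pre-order traversal (rewrite_graph_pre_order) supports rewrites: a callback may return an object other than the one
# it was given, and the returned object replaces the original in the graph.
#
# Post-order traversal (visit_graph_post_order) does not support rewrites: callback return values are ignored, so any
# change must be made to the visited object in place.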
27
28
def rewrite_graph_pre_order(sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
    """Rewrite the subgraph ``sg`` with a pre-order walk that starts at its output tensors.

    Every reachable tensor is passed through the functions in ``tensor_rewrite_list`` and every
    reachable operation through the functions in ``op_rewrite_list``; each function is called as
    ``fn(obj, arch)``. Full passes over the list are repeated until a pass ends with the same object
    it began with (compared with ``!=``), and that final object replaces the original in the graph.
    The walk then continues from the replacement: an operation's inputs and then its outputs, a
    tensor's producing ops (``tens.ops``).

    An operation whose ``run_on_npu`` is falsy is skipped by the op rewrites unless
    ``rewrite_unsupported`` is true.

    Returns ``sg`` after replacing ``sg.output_tensors`` with their rewritten versions and calling
    ``sg.refresh_after_modification()``.
    """
    # Memo tables mapping original -> replacement and replacement -> itself.
    op_replacements = {}
    tensor_replacements = {}

    # NOTE(review): the walk is recursive, so very deep graphs may run into Python's recursion limit.

    def rewrite_to_fixed_point(obj, rewrites, may_rewrite):
        # Apply full passes of ``rewrites`` until a pass returns the object it started with.
        current, previous = obj, None
        while previous != current:
            previous = current
            for rewrite in rewrites:
                # Re-checked against the object as it stands after the preceding rewrites.
                if may_rewrite(current):
                    current = rewrite(current, arch)
        return current

    def visit_op(op):
        if op in op_replacements:
            return op_replacements[op]

        result = rewrite_to_fixed_point(
            op, op_rewrite_list, lambda candidate: candidate.run_on_npu or rewrite_unsupported
        )

        # Memoise before recursing so cycles and shared neighbours terminate.
        op_replacements[op] = result
        op_replacements[result] = result

        # Each list is emptied first and then refilled one tensor at a time, so anything running
        # during the recursion observes the partially rebuilt list on the op.
        pending_inputs = result.inputs
        result.inputs = []
        for tens in pending_inputs:
            result.inputs.append(visit_tens(tens))

        pending_outputs = result.outputs
        result.outputs = []
        for tens in pending_outputs:
            result.outputs.append(visit_tens(tens))

        return result

    def visit_tens(tens):
        if tens in tensor_replacements:
            return tensor_replacements[tens]

        # Tensor rewrites are applied unconditionally.
        result = rewrite_to_fixed_point(tens, tensor_rewrite_list, lambda _candidate: True)

        tensor_replacements[tens] = result
        tensor_replacements[result] = result

        pending_producers = result.ops
        result.ops = []
        for producer in pending_producers:
            result.ops.append(visit_op(producer))

        return result

    sg.output_tensors = list(map(visit_tens, sg.output_tensors))
    sg.refresh_after_modification()

    return sg
84
85
def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list):
    """Walk the subgraph ``sg`` in post-order, starting at its output tensors.

    Each reachable tensor and operation is visited once. For an operation, its input tensors are
    visited first, then every function in ``op_visit_list`` is called as ``fn(op, arch)``, and finally
    its output tensors are visited. For a tensor, its producing ops (``tens.ops``) are visited first,
    then every function in ``tensor_visit_list`` is called as ``fn(tens, arch)``.

    The visit functions' return values are ignored, so they cannot replace objects in the graph;
    any change has to be made in place.

    Returns ``sg`` after calling ``sg.refresh_after_modification()``.
    """
    seen_ops = set()
    seen_tensors = set()

    def visit_op(op):
        if op in seen_ops:
            return
        # Mark before recursing so cycles and shared tensors terminate.
        seen_ops.add(op)

        for tens in op.inputs:
            visit_tens(tens)

        for callback in op_visit_list:
            callback(op, arch)

        # Outputs are walked too, so tensors reachable only as an op's output are still visited.
        for tens in op.outputs:
            visit_tens(tens)

    def visit_tens(tens):
        if tens in seen_tensors:
            return
        seen_tensors.add(tens)

        for producer in tens.ops:
            visit_op(producer)

        for callback in tensor_visit_list:
            callback(tens, arch)

    for tens in sg.output_tensors:
        visit_tens(tens)

    sg.refresh_after_modification()

    return sg
127
128
def verify_graph_health(nng):
    """Run ``verify_subgraph_health`` on every subgraph in ``nng.subgraphs``.

    Returns True. Any failure surfaces as whatever ``verify_subgraph_health`` raises.
    """
    subgraphs = nng.subgraphs
    for subgraph in subgraphs:
        verify_subgraph_health(subgraph)
    return True
135
136
def verify_subgraph_health(sg):
    """Check that the producer/consumer links of ``sg`` are mutually consistent.

    Walks the graph from ``sg.output_tensors`` through each tensor's producing ops and each op's
    inputs and outputs, and checks that:

    * every input tensor of an op lists that op in ``tens.consumers()``;
    * every output tensor of an op lists that op in ``tens.ops``;
    * every op in ``tens.ops`` lists that tensor in ``op.outputs``.

    Returns True if all checks pass.

    Raises:
        AssertionError: describing the first inconsistency found. It is raised explicitly rather
            than through an ``assert`` statement, so the checks still run under ``python -O``.
    """
    visited_ops = set()
    visited_tensors = set()

    def fail(message, *args):
        # Formatting only happens on the failure path.
        raise AssertionError(message.format(*args))

    def visit_op(op):
        if op in visited_ops:
            return
        # Mark before recursing so cycles and shared tensors terminate.
        visited_ops.add(op)

        for tens in op.inputs:
            if op not in tens.consumers():
                fail("Operation {} reads tensor {}, but the tensor does not list it among its consumers", op, tens)
            visit_tens(tens)

        for tens in op.outputs:
            if op not in tens.ops:
                fail("Operation {} writes tensor {}, but the tensor does not list it among its producing ops", op, tens)
            visit_tens(tens)

    def visit_tens(tens):
        if tens in visited_tensors:
            return
        visited_tensors.add(tens)

        for op in tens.ops:
            if tens not in op.outputs:
                fail(
                    "Tensor {} lists operation {} as a producing op, but the operation does not list it among its outputs",
                    tens,
                    op,
                )
            visit_op(op)

    for tens in sg.output_tensors:
        visit_tens(tens)

    return True
171 return True