blob: 42acaf9bc9e59e29c53c5d4438284827a5e91dae [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Functions for abstracting out the traversal and rewriting of graphs so that the optimisation passes can focus on the
18# correct operation.
19#
20# Requires two lists, one of functions that rewrite Tensors, and one of functions that rewrite Operations.
21#
22# Pre-order traversal, this supports rewrites. Therefore, functions can return something other than the original value.
23#
# Post-order traversal does not support rewrites: the visit functions are called only for their side effects and their return values are ignored.
25
26
def rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
    """Rewrite the graph reachable from ``sg.output_tensors`` in pre-order.

    The walk starts at the subgraph's output tensors and moves towards the inputs: each tensor is
    rewritten before the ops in its ``ops`` list are visited, and each op is rewritten before its
    input and output tensors are visited.

    Every rewrite function is called as ``rewrite(obj, arch, nng)`` and returns the object that
    should take the place of ``obj`` (returning ``obj`` itself leaves it unchanged). For every visited
    op/tensor the whole rewrite list is applied repeatedly until one complete pass leaves the object
    unchanged (compared with ``!=``).

    Op rewrites are only applied to ops whose ``run_on_npu`` is truthy, unless ``rewrite_unsupported``
    is True. Tensor rewrites are always applied.

    The subgraph is modified in place: op input/output lists, tensor ``ops`` lists and
    ``sg.output_tensors`` are replaced with the rewritten objects, after which
    ``sg.refresh_after_modification()`` is called. Returns ``sg``.

    NOTE(review): the traversal is recursive, so very long producer chains may run into Python's
    recursion limit.
    """

    # Memoisation: map both the original and the rewritten object to the rewritten object, so each
    # op/tensor is rewritten at most once and later visits return the cached result.
    op_visit_dict = dict()
    tens_visit_dict = dict()

    def visit_op(op):
        if op in op_visit_dict:
            return op_visit_dict[op]
        res = op
        prev_res = None
        # Apply the op rewrites until a full pass no longer changes the op.
        while prev_res != res:
            prev_res = res
            for rewrite in op_rewrite_list:
                # Evaluated against the current (possibly already rewritten) op before each rewrite.
                if res.run_on_npu or rewrite_unsupported:
                    res = rewrite(res, arch, nng)

        # Record the result before recursing: visiting res.outputs leads back to res through the
        # output tensors' ops lists, and these entries make that a cache hit instead of a re-visit.
        op_visit_dict[op] = res
        op_visit_dict[res] = res

        # Replace inputs/outputs with their rewritten versions. Each list is cleared and refilled one
        # entry at a time while nested visits run (which can observe the partially rebuilt list), so
        # the statement order here matters.
        inputs = res.inputs
        res.inputs = []
        for tens in inputs:
            res.inputs.append(visit_tens(tens))

        outputs = res.outputs
        res.outputs = []
        for tens in outputs:
            res.outputs.append(visit_tens(tens))

        return res

    def visit_tens(tens):
        # NOTE(review): tens may presumably be None for an absent (optional) op input; the tensor
        # rewrites below are still called with it, so they need to tolerate None - confirm.
        if tens in tens_visit_dict:
            return tens_visit_dict[tens]

        res = tens
        prev_res = None
        # Apply the tensor rewrites until a full pass no longer changes the tensor.
        while prev_res != res:
            prev_res = res
            for rewrite in tensor_rewrite_list:
                res = rewrite(res, arch, nng)

        tens_visit_dict[tens] = res
        tens_visit_dict[res] = res

        # Falsy results (e.g. None) have no producing ops to visit.
        if res:
            ops = res.ops
            res.ops = []
            for op in ops:
                res.ops.append(visit_op(op))
        return res

    sg.output_tensors = [visit_tens(tens) for tens in sg.output_tensors]
    sg.refresh_after_modification()

    return sg
83
84
def visit_graph_post_order(start_tensors, arch, tensor_visit_list, op_visit_list):
    """Depth-first, read-only walk of a graph starting from ``start_tensors``.

    ``start_tensors`` is typically a subgraph's ``output_tensors``. Every op and tensor
    is visited at most once, in input-to-output order:

    * an op's visitors run after all of its input tensors have been walked and before
      its output tensors are walked;
    * a tensor's visitors run after all ops in its ``ops`` list have been walked.

    Each callable in ``op_visit_list`` is invoked as ``visit(op, arch)`` and each callable
    in ``tensor_visit_list`` as ``visit(tens, arch)``. Their return values are ignored,
    so visitors cannot rewrite the graph. ``None`` tensors (e.g. absent op inputs) are
    skipped.
    """
    seen_ops = set()
    seen_tensors = set()

    def walk_tensor(tensor):
        if tensor is None or tensor in seen_tensors:
            return
        seen_tensors.add(tensor)

        # Producers first, so that the tensor is reported after the ops that create it.
        for producer in tensor.ops:
            walk_op(producer)

        for tensor_visitor in tensor_visit_list:
            tensor_visitor(tensor, arch)

    def walk_op(operation):
        if operation in seen_ops:
            return
        seen_ops.add(operation)

        for inp in operation.inputs:
            walk_tensor(inp)

        for op_visitor in op_visit_list:
            op_visitor(operation, arch)

        for out in operation.outputs:
            walk_tensor(out)

    for start in start_tensors:
        walk_tensor(start)
120
Tim Hall79d07d22020-04-27 18:20:16 +0100121
def verify_graph_health(nng):
    """Run verify_subgraph_health on every subgraph of the network graph ``nng``.

    Returns True. Any exception raised by verify_subgraph_health for an offending
    subgraph propagates to the caller.
    """
    for subgraph in nng.subgraphs:
        verify_subgraph_health(subgraph)
    return True
128
129
def verify_subgraph_health(sg):
    """Check the producer/consumer links of the graph reachable from ``sg.output_tensors``.

    Walks the graph depth-first from the subgraph's output tensors, following each tensor
    to the ops in its ``ops`` list and each op to its inputs and outputs. Every link is
    checked in both directions:

    * each non-None input tensor of an op lists that op in ``tens.consumers()``;
    * each output tensor of an op lists that op in ``tens.ops``;
    * each op in ``tens.ops`` lists that tensor in ``op.outputs``.

    Returns True when every check passes.

    Raises:
        AssertionError: on the first inconsistency found. The checks raise explicitly
            instead of using ``assert`` so that they still run under ``python -O``,
            where assert statements are removed.
    """
    visited_ops = set()
    visited_tensors = set()

    def visit_op(op):
        if op in visited_ops:
            return
        visited_ops.add(op)

        for tens in op.inputs:
            # Absent (optional) inputs are stored as None; there is no link to check.
            if tens is None:
                continue
            if op not in tens.consumers():
                raise AssertionError("{} is not listed as a consumer of its input tensor {}".format(op, tens))
            visit_tens(tens)

        for tens in op.outputs:
            if op not in tens.ops:
                raise AssertionError("{} is not listed in the ops of its output tensor {}".format(op, tens))
            visit_tens(tens)

    def visit_tens(tens):
        if tens in visited_tensors:
            return
        visited_tensors.add(tens)

        for op in tens.ops:
            if tens not in op.outputs:
                raise AssertionError("{} is not among the outputs of its producing op {}".format(tens, op))
            visit_op(op)

    for tens in sg.output_tensors:
        visit_tens(tens)

    return True