blob: b29f7a280519d1978e9ed27699530de19f018523 [file] [log] [blame]
Rickard Bolinbc6ee582022-11-04 08:24:29 +00001# SPDX-FileCopyrightText: Copyright 2020 Arm Limited and/or its affiliates <open-source-office@arm.com>
Tim Hall79d07d22020-04-27 18:20:16 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
# Description:
# Functions for abstracting out the traversal and rewriting of graphs so that the optimisation passes can focus on the
# correct operation.
#
# The traversal functions take two lists: one of functions that process Tensors, and one of functions that process
# Operations.
#
# Pre-order traversal supports rewrites, so its functions may return something other than the original value.
#
# Post-order traversal does not support rewrites; the return values of its functions are ignored.
26
27
def rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
    """Rewrite the tensors and operations reachable from ``sg.output_tensors`` in pre-order.

    The walk goes from each tensor to the operations in its ``ops`` list, and from each operation to its
    ``inputs`` and ``outputs``. Every object is rewritten before its neighbours are visited, so a rewrite
    function may return a different object than the one it was given. The returned object is stored back
    into the graph in place of the original.

    Each rewrite function is called as ``rewrite(obj, arch, nng)``. The whole list is applied repeatedly
    until a complete pass leaves the object unchanged, i.e. until the object at the end of a pass no longer
    compares unequal (``!=``) to the one the pass started with. Every object is processed at most once;
    results are memoised under both the original and the replacement object.

    Args:
        nng: Passed through unchanged to every rewrite function.
        sg: Subgraph whose ``output_tensors`` are replaced by their rewritten versions; its
            ``refresh_after_modification()`` is called once the walk is complete.
        arch: Passed through unchanged to every rewrite function.
        tensor_rewrite_list: Functions applied to each tensor.
        op_rewrite_list: Functions applied to each operation.
        rewrite_unsupported: When False, an operation rewrite is only applied while the current operation's
            ``run_on_npu`` is truthy.

    Returns:
        ``sg`` itself.
    """
    rewritten_ops = {}
    rewritten_tensors = {}

    def op_is_eligible(op):
        # Checked before every individual rewrite, so an earlier rewrite in the same pass can change the outcome.
        return op.run_on_npu or rewrite_unsupported

    def tensor_is_eligible(_tens):
        return True

    def apply_until_stable(item, rewrites, is_eligible):
        """Run ``rewrites`` over ``item`` repeatedly until a full pass no longer changes it."""
        previous, current = None, item
        # A None item never enters the loop (None != None is False), so no rewrite is called on it.
        while previous != current:
            previous = current
            for rewrite in rewrites:
                if is_eligible(current):
                    current = rewrite(current, arch, nng)
        return current

    def visit_op(op):
        if op in rewritten_ops:
            return rewritten_ops[op]

        new_op = apply_until_stable(op, op_rewrite_list, op_is_eligible)
        # Memoise before recursing, so that reaching this op again through a neighbouring tensor returns the
        # stored result instead of recursing forever.
        rewritten_ops[op] = new_op
        rewritten_ops[new_op] = new_op

        # Rebuild each connection list in place, one visited entry at a time.
        pending, new_op.inputs = new_op.inputs, []
        for tens in pending:
            new_op.inputs.append(visit_tens(tens))

        pending, new_op.outputs = new_op.outputs, []
        for tens in pending:
            new_op.outputs.append(visit_tens(tens))

        return new_op

    def visit_tens(tens):
        if tens in rewritten_tensors:
            return rewritten_tensors[tens]

        new_tens = apply_until_stable(tens, tensor_rewrite_list, tensor_is_eligible)
        rewritten_tensors[tens] = new_tens
        rewritten_tensors[new_tens] = new_tens

        # Falsy results (for example None) are not descended into.
        if new_tens:
            producers, new_tens.ops = new_tens.ops, []
            for op in producers:
                new_tens.ops.append(visit_op(op))
        return new_tens

    sg.output_tensors = list(map(visit_tens, sg.output_tensors))
    sg.refresh_after_modification()

    return sg
84
85
def visit_graph_post_order(start_tensors, arch, tensor_visit_list, op_visit_list):
    """Visit every tensor and operation reachable from ``start_tensors`` in input-to-output order.

    This is a depth-first walk, typically started from a subgraph's ``output_tensors``. A tensor's
    producing operations (its ``ops``) are visited before the tensor's own visitors run. An operation's
    ``inputs`` are visited before its visitors run, and its ``outputs`` afterwards. Each visitor is called
    as ``visit(obj, arch)`` and its return value is ignored, so this traversal never replaces graph
    objects. ``None`` tensors are skipped, and every object is visited at most once.

    Args:
        start_tensors: Tensors to start the walk from.
        arch: Passed through unchanged to every visitor.
        tensor_visit_list: Callables invoked for each tensor.
        op_visit_list: Callables invoked for each operation.
    """
    seen_ops = set()
    seen_tensors = set()

    def visit_op(op):
        if op in seen_ops:
            return
        seen_ops.add(op)

        for input_tens in op.inputs:
            visit_tens(input_tens)
        for callback in op_visit_list:
            callback(op, arch)
        for output_tens in op.outputs:
            visit_tens(output_tens)

    def visit_tens(tens):
        if tens is None or tens in seen_tensors:
            return
        seen_tensors.add(tens)

        for producer in tens.ops:
            visit_op(producer)
        for callback in tensor_visit_list:
            callback(tens, arch)

    for start in start_tensors:
        visit_tens(start)
121
Tim Hall79d07d22020-04-27 18:20:16 +0100122
def verify_graph_health(nng):
    """Run ``verify_subgraph_health`` on every subgraph in ``nng.subgraphs``.

    Any inconsistency is reported by ``verify_subgraph_health`` raising. This function always returns
    True, presumably so that callers can wrap the call in an ``assert`` — confirm against call sites.
    """
    for subgraph in nng.subgraphs:
        verify_subgraph_health(subgraph)
    return True
129
130
def verify_subgraph_health(sg):
    """Assert that the producer/consumer links reachable from ``sg.output_tensors`` are consistent.

    For every reachable operation, each truthy input tensor must list the operation in ``consumers()``,
    and each output tensor must list it in ``ops``. For every reachable tensor, each operation in its
    ``ops`` must list the tensor in ``outputs``. The first violation raises ``AssertionError``. The checks
    are plain ``assert`` statements, so they are skipped under ``python -O``.

    Returns:
        True, once the whole reachable graph has been checked.
    """
    checked_ops = set()
    checked_tensors = set()

    def check_op(op):
        if op in checked_ops:
            return
        checked_ops.add(op)

        for tens in op.inputs:
            # Falsy entries (presumably absent optional inputs) have no links to check.
            if tens:
                assert op in tens.consumers()
                check_tensor(tens)

        for tens in op.outputs:
            assert op in tens.ops
            check_tensor(tens)

    def check_tensor(tens):
        if tens in checked_tensors:
            return
        checked_tensors.add(tens)

        for producer in tens.ops:
            assert tens in producer.outputs
            check_op(producer)

    for tens in sg.output_tensors:
        check_tensor(tens)

    return True
167 return True