| # Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. |
| # |
| # SPDX-License-Identifier: Apache-2.0 |
| # |
| # Licensed under the Apache License, Version 2.0 (the License); you may |
| # not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # Description: |
| # Functions for abstracting out the traversal and rewriting of graphs so that the optimisation passes can focus on the |
| # correct operation. |
| # |
| # Requires two lists, one of functions that rewrite Tensors, and one of functions that rewrite Operations. |
| # |
| # Pre-order traversal, this supports rewrites. Therefore, functions can return something other than the original value. |
| # |
| # Post-order traversal, this does not support rewrites. Therefore, functions must return the original value. |
| |
| |
def rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
    """Rewrite every op and tensor reachable from ``sg.output_tensors``, node first, links second.

    Each node is passed through its rewrite list before the traversal follows its links, so a rewrite may
    return a replacement node; the links of the *replacement* are the ones that get followed and rebuilt.

    Args:
        nng: Passed unchanged as the third argument of every rewrite function.
        sg: Subgraph-like object providing ``output_tensors`` and ``refresh_after_modification()``.
        arch: Passed unchanged as the second argument of every rewrite function.
        tensor_rewrite_list: Callables ``f(tens, arch, nng)`` returning the (possibly replaced) tensor.
        op_rewrite_list: Callables ``f(op, arch, nng)`` returning the (possibly replaced) op.
        rewrite_unsupported: When falsy, op rewrites are applied only while ``op.run_on_npu`` is truthy.

    Returns:
        ``sg``, after ``output_tensors`` has been replaced by the rewritten tensors and
        ``refresh_after_modification()`` has been called.
    """
    rewritten_ops = {}
    rewritten_tensors = {}

    def settle(start, rewrites, is_eligible):
        # Run the whole rewrite list again and again until a complete pass hands back an unchanged value.
        # Starting with ``previous = None`` means a ``None`` start value never enters the loop at all.
        current = start
        previous = None
        while previous != current:
            previous = current
            for rewrite in rewrites:
                if is_eligible(current):
                    current = rewrite(current, arch, nng)
        return current

    def visit_op(op):
        if op in rewritten_ops:
            return rewritten_ops[op]

        new_op = settle(op, op_rewrite_list, lambda candidate: candidate.run_on_npu or rewrite_unsupported)

        # Record the result under both the original and the replacement before recursing, so reaching this
        # op again through one of its own tensors returns the cached value instead of recursing forever.
        rewritten_ops[op] = new_op
        rewritten_ops[new_op] = new_op

        # The list is emptied first and refilled one entry at a time while the traversal is in progress.
        pending_inputs = new_op.inputs
        new_op.inputs = []
        for in_tens in pending_inputs:
            new_op.inputs.append(visit_tens(in_tens))

        pending_outputs = new_op.outputs
        new_op.outputs = []
        for out_tens in pending_outputs:
            new_op.outputs.append(visit_tens(out_tens))

        return new_op

    def visit_tens(tens):
        if tens in rewritten_tensors:
            return rewritten_tensors[tens]

        new_tens = settle(tens, tensor_rewrite_list, lambda _candidate: True)

        rewritten_tensors[tens] = new_tens
        rewritten_tensors[new_tens] = new_tens

        # Falsy results (such as None) have no ops to follow.
        if new_tens:
            pending_ops = new_tens.ops
            new_tens.ops = []
            for producer in pending_ops:
                new_tens.ops.append(visit_op(producer))
        return new_tens

    new_outputs = []
    for out_tens in sg.output_tensors:
        new_outputs.append(visit_tens(out_tens))
    sg.output_tensors = new_outputs

    sg.refresh_after_modification()

    return sg
| |
| |
def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list):
    """Call visitor functions on every op and tensor reachable from ``sg.output_tensors``.

    The walk goes from a tensor to the ops in ``tens.ops`` and from an op to its ``inputs`` and ``outputs``.
    An op's visitors run after the walk has descended into its inputs and before it descends into its
    outputs; a tensor's visitors run after the walk has descended into the ops in ``tens.ops``.
    Visitor return values are ignored, so visitors cannot replace nodes.

    ``None`` tensors are skipped: they are neither dereferenced nor handed to the tensor visitors.

    Args:
        sg: Subgraph-like object providing ``output_tensors`` and ``refresh_after_modification()``.
        arch: Passed unchanged as the second argument of every visitor.
        tensor_visit_list: Callables ``f(tens, arch)`` invoked once per reachable tensor.
        op_visit_list: Callables ``f(op, arch)`` invoked once per reachable op.

    Returns:
        ``sg``, after ``refresh_after_modification()`` has been called on it.
    """
    visited_ops = set()
    visited_tensors = set()

    def visit_op(op):
        if op in visited_ops:
            return
        # Mark before recursing: the op is reachable again through its own output tensors.
        visited_ops.add(op)

        for tens in op.inputs:
            visit_tens(tens)

        for visit in op_visit_list:
            visit(op, arch)

        for tens in op.outputs:
            visit_tens(tens)

    def visit_tens(tens):
        # A None entry has no ``ops`` to follow; without this guard the walk raises AttributeError.
        if tens is None or tens in visited_tensors:
            return
        visited_tensors.add(tens)

        for op in tens.ops:
            visit_op(op)

        for visit in tensor_visit_list:
            visit(tens, arch)

    for tens in sg.output_tensors:
        visit_tens(tens)

    sg.refresh_after_modification()

    return sg
| |
| |
def verify_graph_health(nng):
    """Run ``verify_subgraph_health`` on each entry of ``nng.subgraphs``.

    Returns:
        True once every subgraph has been checked without an exception being raised.
    """
    for subgraph in nng.subgraphs:
        verify_subgraph_health(subgraph)

    return True
| |
| |
def verify_subgraph_health(sg):
    """Assert that the op/tensor links reachable from ``sg.output_tensors`` agree with each other.

    For every reachable op and tensor the following are asserted:
      * each truthy input tensor of an op lists that op in ``tens.consumers()``;
      * each output tensor of an op lists that op in ``tens.ops``;
      * each op in ``tens.ops`` lists that tensor in ``op.outputs``.

    Falsy input entries (such as None) are ignored.

    Returns:
        True when no assertion fails.

    Raises:
        AssertionError: When one of the link checks above does not hold.
    """
    seen_ops = set()
    seen_tensors = set()

    def check_op(op):
        if op in seen_ops:
            return
        # Mark before recursing: the op is reachable again through its own tensors.
        seen_ops.add(op)

        for in_tens in op.inputs:
            if in_tens:
                assert op in in_tens.consumers()
                check_tens(in_tens)

        for out_tens in op.outputs:
            assert op in out_tens.ops
            check_tens(out_tens)

    def check_tens(tens):
        if tens in seen_tensors:
            return
        seen_tensors.add(tens)

        for producer in tens.ops:
            assert tens in producer.outputs
            check_op(producer)

    for out_tens in sg.output_tensors:
        check_tens(out_tens)

    return True