| # Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved. |
| # |
| # SPDX-License-Identifier: Apache-2.0 |
| # |
| # Licensed under the Apache License, Version 2.0 (the License); you may |
| # not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an AS IS BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # Description: |
| # Early optimisation of the network graph, using the rewrite_graph module to do the traversal of the graph. |
| from . import rewrite_graph |
| from .graph_optimiser_util import check_format_restrictions |
| from .graph_optimiser_util import check_memory_only_removed |
| from .graph_optimiser_util import record_optimised |
| from .nn_graph import NetworkType |
| from .tflite_graph_optimiser import tflite_optimise_graph |
| from .tosa_graph_optimiser import tosa_optimise_graph |
| |
| |
| def optimise_graph(nng, arch, network_type, verbose_graph=False): |
| if verbose_graph: |
| nng.print_graph("Before Graph Optimization") |
| |
| if network_type == NetworkType.TFLite: |
| # TensorFlow Lite graph optimization |
| nng = tflite_optimise_graph(nng, arch) |
| else: |
| # TOSA graph optimization |
| nng = tosa_optimise_graph(nng, arch) |
| |
| # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph |
| for sg in nng.subgraphs: |
| rewrite_graph.visit_graph_post_order( |
| sg.output_tensors, arch, [check_format_restrictions], [check_memory_only_removed, record_optimised] |
| ) |
| |
| if verbose_graph: |
| nng.print_graph("After Graph Optimization") |
| return nng |