blob: 33f1a02c9dc5218801cffc805445bb052696cec6 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17
18# Description:
19# Insert DMA operations into the graph for transferring weights.
20
Tim Hall79d07d22020-04-27 18:20:16 +010021from . import rewrite_graph
Diego Russoea6111a2020-04-14 18:41:58 +010022from .tensor import MemArea, TensorPurpose
23from .operation import Operation, NpuBlockType
Tim Hall79d07d22020-04-27 18:20:16 +010024
25
def insert_dma_cmd(op, arch):
    """Route off-chip weight inputs of ``op`` through a DMA copy in fast storage.

    For every input tensor of ``op`` that meets all of the following, a DMA is inserted:

    * it lives in ``MemArea.Dram`` or ``MemArea.OffChipFlash``;
    * that area is not ``arch.fast_storage_mem_area``;
    * its purpose is ``TensorPurpose.Weights``.

    Inserting the DMA means three things:

    * a copy of the tensor is created with ``tens.clone_into_fast_storage(arch)``;
    * a new ``"DMA"`` Operation is created that reads the original tensor and writes the copy;
    * ``op`` is rewired to read the copy instead.

    No DMA is inserted when every consumer of the tensor is a VectorProduct NPU block.

    Args:
        op: operation to rewrite. Its ``inputs`` list is modified in place.
        arch: architecture description. It supplies ``fast_storage_mem_area`` and is
            passed to ``clone_into_fast_storage``.

    Returns:
        ``op`` itself. It is returned untouched if it already is a ``"DMA"`` operation.
    """
    if op.type == "DMA":
        return op  # Already rewritten

    for idx, tens in enumerate(op.inputs):
        # Only weights held in DRAM / off-chip flash are candidates, and only when
        # that memory area is not already the fast-storage area.
        off_chip = tens.mem_area in (MemArea.Dram, MemArea.OffChipFlash) and tens.mem_area != arch.fast_storage_mem_area
        if not (off_chip and tens.purpose == TensorPurpose.Weights):
            continue

        # Tensor products have no need for DMA: their tensors are only read once and can stay in flash.
        # Other operations re-read tensors, which is better done from SRAM.
        # A missing (None) consumer is treated like a non-VectorProduct consumer.
        needs_dma = any(
            oper is None or oper.attrs.get("npu_block_type") != NpuBlockType.VectorProduct
            for oper in tens.consumers()
        )
        if not needs_dma:
            continue

        # Insert a DMA command here, as well as a new tensor situated in SRAM of the same size.
        new_tens = tens.clone_into_fast_storage(arch)
        # Naming assumes the tensor has at least one producing op (tens.ops[0]).
        dma_cmd = Operation("DMA", tens.ops[0].name + "_dma")
        dma_cmd.inputs = [tens]
        dma_cmd.outputs = [new_tens]
        dma_cmd.attrs["source"] = tens.mem_area
        dma_cmd.attrs["destination"] = new_tens.mem_area
        dma_cmd.run_on_npu = True
        new_tens.ops = [dma_cmd]
        # Replacing by index keeps the inputs list length unchanged while enumerating it.
        # NOTE(review): no consumer bookkeeping is updated on tens/new_tens here
        # (op stays among tens' consumers, dma_cmd is not added) — presumably
        # rebuilt by a later pass; confirm.
        op.inputs[idx] = new_tens
    return op
53
54
def insert_dma_commands(nng, arch, verbose_graph=False):
    """Run the ``insert_dma_cmd`` rewrite over every subgraph of ``nng``.

    Each subgraph is passed to ``rewrite_graph.rewrite_graph_pre_order``, with
    ``[insert_dma_cmd]`` as the operation-rewrite list. The empty list is presumably
    the tensor-rewrite list; confirm against rewrite_graph. The returned subgraph is
    stored back at the same position in ``nng.subgraphs``.

    Args:
        nng: network graph. Its ``subgraphs`` list is updated in place.
        arch: architecture description, forwarded to the rewrite.
        verbose_graph: when True, call ``nng.print_graph()`` once all subgraphs
            have been rewritten.

    Returns:
        The same ``nng`` object.
    """
    for sg_index, subgraph in enumerate(nng.subgraphs):
        rewritten = rewrite_graph.rewrite_graph_pre_order(subgraph, arch, [], [insert_dma_cmd])
        nng.subgraphs[sg_index] = rewritten

    if verbose_graph:
        nng.print_graph()

    return nng