# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Vela separates CPU operations and NPU operations into separate internal subgraphs. The CPU operations are left
# untouched in the final output.
#
# Vela does this by identifying NPU passes and pulling them out from the main CPU graph into separate subgraphs, invoked
# by NpuOp operations. Later, Vela generates command streams and compressed weight streams for the NPU subgraphs and
# attaches them to the NpuOp. This encapsulates everything the NPU subgraph is supposed to do.
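#
# As an illustrative sketch (the operators and CPU placement here are hypothetical, not taken from any particular
# network): a CPU graph
#
#   Conv2D -> Softmax -> Conv2D
#
# where only Softmax has to stay on the CPU would be rewritten into three subgraphs:
#
#   CPU subgraph:            NpuOp(split_1) -> Softmax -> NpuOp(split_2)
#   NPU subgraph "split_1":  Conv2D
#   NPU subgraph "split_2":  Conv2D
#
# with every tensor that crosses a CPU/NPU boundary cloned and re-wired by the rewrite functions below.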
import numpy as np

from .nn_graph import Pass
from .nn_graph import PassPlacement
from .nn_graph import Subgraph
from .operation import NpuBlockType
from .operation import Operation


def make_npu_call_op_pass(npu_subgraph):
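    """Create a MemoryOnly pass containing a single NpuOp that calls npu_subgraph from the CPU graph."""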
    op = Operation("NpuOp", "call_" + npu_subgraph.name)
    op.attrs["subgraph"] = npu_subgraph
    ps = Pass(op.name, PassPlacement.MemoryOnly, False, NpuBlockType.Default)
    ps.ops = [op]
    ps.primary_op = op
    op.attrs["npu_block_type"] = ps.npu_block_type
    op.scheduled_pass = ps

    # Inputs and outputs filled in later as we cut the graphs
    return ps


def switch_tensor_for_op(op, orig_tens, new_tens):
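    """Swap orig_tens for new_tens in the inputs and outputs of op and of op's scheduled pass."""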
    op.inputs = [new_tens if tens == orig_tens else tens for tens in op.inputs]
    op.outputs = [new_tens if tens == orig_tens else tens for tens in op.outputs]

    ps = op.scheduled_pass
    if ps is None:
        return

    ps.inputs = [new_tens if tens == orig_tens else tens for tens in ps.inputs]
    ps.outputs = [new_tens if tens == orig_tens else tens for tens in ps.outputs]

    if ps.ifm_tensor == orig_tens:
        ps.ifm_tensor = new_tens
    if ps.ifm2_tensor == orig_tens:
        ps.ifm2_tensor = new_tens
    if ps.ofm_tensor == orig_tens:
        ps.ofm_tensor = new_tens
    if ps.weight_tensor == orig_tens:
        ps.weight_tensor = new_tens
    if ps.scale_tensor == orig_tens:
        ps.scale_tensor = new_tens


def rewrite_tensor_cpu_producer_npu_consumers(
    orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
):
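    """Rewrite a tensor produced on the CPU and consumed by the NPU subgraph.

    The tensor is cloned into the NPU subgraph and fed there by a SubgraphInput (or Const) op in
    the startup-init pass; NPU-side consumers are switched over to the clone. Non-constant tensors
    also become inputs of the NpuOp call pass, so the CPU graph passes them in at runtime.
    """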
    is_const = orig_tens.ops[0].type == "Const"

    new_tens = orig_tens.clone("_npu")
    orig_tens.npu_tensor = new_tens
    new_tens.cpu_tensor = orig_tens

    op_type = "SubgraphInput"
    if is_const:
        op_type = "Const"
    op = Operation(op_type, orig_tens.name + "_input")
    op.attrs["npu_block_type"] = NpuBlockType.Default
    op.outputs = [new_tens]
    op.scheduled_pass = startup_init_ps
    new_tens.ops = [op]
    startup_init_ps.ops.append(op)
    startup_init_ps.outputs.append(new_tens)

    if not is_const:
        call_ps.inputs.append(orig_tens)
        call_ps.primary_op.inputs.append(orig_tens)

    for op in list(orig_tens.consumers()):
        if op is None:
            continue  # Subgraph consumers handled separately.
        ps = op.scheduled_pass
        if subgraph_for_pass[ps] == npu_subgraph:
            switch_tensor_for_op(op, orig_tens, new_tens)
            orig_tens.consumer_list.remove(op)
            new_tens.consumer_list.append(op)

    # Deal with output tensors for the NPU graph; graph outputs have no consumer ops, so they are
    # not rewired by the loop above.
    npu_subgraph.output_tensors = [new_tens if tens == orig_tens else tens for tens in npu_subgraph.output_tensors]


def rewrite_tensor_npu_producer_cpu_consumers(
    orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
):
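    """Rewrite a tensor produced by the NPU subgraph and consumed on the CPU.

    The tensor becomes an output of both the NPU subgraph and the NpuOp call pass; a CPU-side
    clone is created and all CPU-side consumers are switched over to it.
    """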
    new_tens = orig_tens.clone("_cpu")
    new_tens.npu_tensor = orig_tens
    orig_tens.cpu_tensor = new_tens

    npu_subgraph.output_tensors.append(orig_tens)

    call_ps.outputs.append(new_tens)
    call_ps.primary_op.outputs.append(new_tens)
    new_tens.ops = [call_ps.primary_op]

    for op in list(orig_tens.consumers()):
        if op is None:
            continue  # Subgraph consumers handled separately.
        ps = op.scheduled_pass
        if subgraph_for_pass[ps] != npu_subgraph:
            switch_tensor_for_op(op, orig_tens, new_tens)
            orig_tens.consumer_list.remove(op)
            new_tens.consumer_list.append(op)

    # Deal with output tensors for the CPU graph; graph outputs have no consumer ops, so they are
    # not rewired by the loop above.
    cpu_subgraph.output_tensors = [new_tens if tens == orig_tens else tens for tens in cpu_subgraph.output_tensors]


def extract_subgraph(nng, orig_sg, arch):
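    """Split the NPU passes of the CPU subgraph orig_sg out into new NPU subgraphs.

    Returns the list of newly created NPU subgraphs, or an empty list if every pass stays on the CPU.
    """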
    assert orig_sg.placement == PassPlacement.Cpu

    passes = list(orig_sg.passes)
    place_vec = np.array([ps.placement for ps in passes])
    place_vec[
        place_vec == PassPlacement.StartupInit
    ] = PassPlacement.Cpu  # Keep the startup init pass on the CPU; we'll make new ones to move onto the NPU.

    # MemoryOnly passes that are either squeezed between NPU passes or on the boundary of NPU and CPU
    # passes should be assigned to the NPU.
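    #
    # E.g. a placement sequence [Npu, MemoryOnly, Npu] becomes [Npu, Npu, Npu] on the forward sweep,
    # while [MemoryOnly, Npu] needs the backward sweep: going forward, the leading MemoryOnly pass is
    # left alone (no NPU pass precedes it), and going backward it is pulled onto the NPU.
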
    # Forward sweep, then backward sweep
    for is_reversed in range(2):
        last_place = PassPlacement.Cpu
        seq = enumerate(place_vec)
        if is_reversed:
            seq = reversed(list(seq))
        for idx, place in seq:
            if place == PassPlacement.MemoryOnly:
                if last_place == PassPlacement.Npu:
                    place = PassPlacement.Npu
                    place_vec[idx] = place

            if place != PassPlacement.MemoryOnly:
                last_place = place

    # Anything left, assign to the CPU.
    place_vec[place_vec == PassPlacement.MemoryOnly] = PassPlacement.Cpu

    if np.all(place_vec == PassPlacement.Cpu):
        return []  # Nothing to do

    # Create the subgraphs and split passes between them

    new_subgraphs = []
    split_count = 0
    subgraph_for_pass = {}
    orig_sg.passes = []
    call_pass = {}
    startup_init_passes = {}

    last_place = PassPlacement.Cpu
    curr_sg = orig_sg

    for idx, place in enumerate(place_vec):
        if place != last_place:
            if place == PassPlacement.Npu:
                split_count += 1
                curr_sg = Subgraph("%s_split_%d" % (orig_sg.name, split_count), PassPlacement.Npu)
                new_subgraphs.append(curr_sg)
                call_ps = make_npu_call_op_pass(curr_sg)
                subgraph_for_pass[call_ps] = orig_sg
                orig_sg.passes.append(call_ps)
                call_pass[curr_sg] = call_ps

                startup_init_ps = Pass(
                    curr_sg.name + "_startup_init", PassPlacement.StartupInit, False, NpuBlockType.Default
                )
                curr_sg.passes.append(startup_init_ps)
                startup_init_passes[curr_sg] = startup_init_ps
                subgraph_for_pass[startup_init_ps] = curr_sg

            else:
                curr_sg = orig_sg
            last_place = place
        ps = passes[idx]
        subgraph_for_pass[ps] = curr_sg
        curr_sg.passes.append(ps)

    # Rewrite tensors to fix up the graphs.

    for curr_sg in new_subgraphs:
        for ps in curr_sg.passes:
            for tens in ps.inputs:
                source_sgs = [subgraph_for_pass[op.scheduled_pass] for op in tens.ops]
                assert len(source_sgs) > 0  # Every input tensor must have at least one producer.
                producer_sg = source_sgs[0]
                for sg in source_sgs:
                    assert sg == producer_sg  # All producers need to be in the same subgraph.

                if producer_sg != curr_sg:
                    assert (
                        producer_sg == orig_sg
                    )  # Because we go in order, all the producers must be in the original graph.
                    rewrite_tensor_cpu_producer_npu_consumers(
                        tens, call_pass[curr_sg], startup_init_passes[curr_sg], curr_sg, orig_sg, subgraph_for_pass
                    )

            for tens in ps.outputs:
                dest_sgs = [subgraph_for_pass[op.scheduled_pass] for op in tens.consumers() if op is not None]
                need_rewrite = False
                for sg in dest_sgs:
                    if sg != curr_sg:
                        need_rewrite = True
                        break
                if tens in orig_sg.output_tensors:
                    need_rewrite = True

                if need_rewrite:
                    rewrite_tensor_npu_producer_cpu_consumers(
                        tens, call_pass[curr_sg], startup_init_passes[curr_sg], curr_sg, orig_sg, subgraph_for_pass
                    )

    return new_subgraphs


def extract_npu_subgraphs(nng, arch):
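    """Entry point: extract NPU subgraphs from every CPU subgraph in nng and append them to nng.subgraphs."""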
    nng.refresh_after_modification()

    for sg in list(nng.subgraphs):
        if sg.placement == PassPlacement.Cpu:
            new_subgraphs = extract_subgraph(nng, sg, arch)
            nng.subgraphs += new_subgraphs

    nng.refresh_after_modification()
    nng.prune_startup_init_pass()

    for sg in nng.subgraphs:
        sg.build_pass_links()