# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# Vela separates CPU operations and NPU operations into separate internal subgraphs. The CPU operations are left
# untouched in the final output.
#
# Vela does this by identifying NPU passes and pulling them out of the main CPU graph into separate subgraphs, each
# invoked by an NpuOp operation. Later, Vela generates command streams and compressed weight streams for the NPU
# subgraphs and attaches them to the NpuOp. This encapsulates everything the NPU subgraph is supposed to do.
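#
# As an illustrative sketch (not tied to any particular network): a CPU subgraph whose passes are placed
#     [StartupInit, Cpu, Npu, Npu, Cpu]
# is rewritten into a CPU subgraph placed
#     [StartupInit, Cpu, NpuOp call, Cpu]
# plus one new NPU subgraph that holds the two NPU passes and its own startup-init pass.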

from .nn_graph import Pass, PassPlacement, NpuBlockType, Subgraph
from .operation import Operation
import numpy as np


def make_npu_call_op_pass(npu_subgraph):
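    # Creates a MemoryOnly pass wrapping a single NpuOp operation that invokes npu_subgraph.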
    op = Operation("NpuOp", "call_" + npu_subgraph.name)
    op.attrs["subgraph"] = npu_subgraph
    ps = Pass(op.name, PassPlacement.MemoryOnly, False, NpuBlockType.Default)
    ps.ops = [op]
    ps.primary_op = op
    op.attrs["npu_block_type"] = ps.npu_block_type
    op.scheduled_pass = ps

    # Inputs and outputs are filled in later, as we cut the graphs.
    return ps


def switch_tensor_for_op(op, orig_tens, new_tens):
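    # Replaces orig_tens with new_tens in op's inputs and outputs, and in the input/output
    # lists and tensor fields of op's scheduled pass.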

    op.inputs = [new_tens if tens == orig_tens else tens for tens in op.inputs]
    op.outputs = [new_tens if tens == orig_tens else tens for tens in op.outputs]

    ps = op.scheduled_pass
    if ps is None:
        return

    ps.inputs = [new_tens if tens == orig_tens else tens for tens in ps.inputs]
    ps.outputs = [new_tens if tens == orig_tens else tens for tens in ps.outputs]

    if ps.ifm_tensor == orig_tens:
        ps.ifm_tensor = new_tens
    if ps.ifm2_tensor == orig_tens:
        ps.ifm2_tensor = new_tens
    if ps.ofm_tensor == orig_tens:
        ps.ofm_tensor = new_tens
    if ps.weight_tensor == orig_tens:
        ps.weight_tensor = new_tens
    if ps.scale_tensor == orig_tens:
        ps.scale_tensor = new_tens


def rewrite_tensor_cpu_producer_npu_consumers(
    orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
):
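    # Cuts a tensor that is produced on the CPU but consumed on the NPU: the NPU consumers are
    # switched over to a clone fed by a Const or SubgraphInput op in the NPU subgraph's
    # startup-init pass, and non-constant tensors also become inputs of the NpuOp call pass.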
    is_const = orig_tens.ops[0].type == "Const"

    new_tens = orig_tens.clone("_npu")
    orig_tens.npu_tensor = new_tens
    new_tens.cpu_tensor = orig_tens

    op_type = "SubgraphInput"
    if is_const:
        op_type = "Const"
    op = Operation(op_type, orig_tens.name + "_input")
    op.attrs["npu_block_type"] = NpuBlockType.Default
    op.outputs = [new_tens]
    op.scheduled_pass = startup_init_ps
    new_tens.ops = [op]
    startup_init_ps.ops.append(op)
    startup_init_ps.outputs.append(new_tens)

    if not is_const:
        call_ps.inputs.append(orig_tens)
        call_ps.primary_op.inputs.append(orig_tens)

    for op in list(orig_tens.consumers()):
        if op is None:
            continue  # A None consumer marks a subgraph output; those are handled separately.
        ps = op.scheduled_pass
        if subgraph_for_pass[ps] == npu_subgraph:
            switch_tensor_for_op(op, orig_tens, new_tens)
            orig_tens.consumer_list.remove(op)
            new_tens.consumer_list.append(op)

    # Deal with output tensors for the NPU graph. These are special: graph outputs are listed in
    # the subgraph's output_tensors rather than reached through a consuming op.
    npu_subgraph.output_tensors = [new_tens if tens == orig_tens else tens for tens in npu_subgraph.output_tensors]


def rewrite_tensor_npu_producer_cpu_consumers(
    orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
):
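    # Cuts a tensor that is produced on the NPU but consumed on the CPU: the tensor becomes an
    # output of both the NPU subgraph and the NpuOp call pass, and the CPU consumers are
    # switched over to a clone produced by the NpuOp.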

    new_tens = orig_tens.clone("_cpu")
    new_tens.npu_tensor = orig_tens
    orig_tens.cpu_tensor = new_tens

    npu_subgraph.output_tensors.append(orig_tens)

    call_ps.outputs.append(new_tens)
    call_ps.primary_op.outputs.append(new_tens)
    new_tens.ops = [call_ps.primary_op]

    for op in list(orig_tens.consumers()):
        if op is None:
            continue  # A None consumer marks a subgraph output; those are handled separately.
        ps = op.scheduled_pass
        if subgraph_for_pass[ps] != npu_subgraph:
            switch_tensor_for_op(op, orig_tens, new_tens)
            orig_tens.consumer_list.remove(op)
            new_tens.consumer_list.append(op)

    # Deal with output tensors for the CPU graph. These are special: graph outputs are listed in
    # the subgraph's output_tensors rather than reached through a consuming op.
    cpu_subgraph.output_tensors = [new_tens if tens == orig_tens else tens for tens in cpu_subgraph.output_tensors]


def extract_subgraph(nng, orig_sg, arch):
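    # Splits the NPU passes of a CPU subgraph out into new NPU subgraphs, leaving orig_sg with
    # the CPU passes plus one NpuOp call pass per extracted subgraph. Returns the new subgraphs.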
    assert orig_sg.placement == PassPlacement.Cpu

    passes = list(orig_sg.passes)
    place_vec = np.array([ps.placement for ps in passes])
    place_vec[
        place_vec == PassPlacement.StartupInit
    ] = PassPlacement.Cpu  # Keep the startup init pass on the CPU; we'll make new ones to move onto the NPU.

    # MemoryOnly passes that are either squeezed between NPU passes or on the boundary of NPU and CPU
    # passes should be assigned to the NPU.
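    # As an illustrative example, a placement sequence
    #     [Cpu, Npu, MemoryOnly, Npu, MemoryOnly, Cpu]
    # ends up as
    #     [Cpu, Npu, Npu, Npu, Npu, Cpu]
    # after the two sweeps below: the forward sweep claims MemoryOnly passes that follow an NPU
    # pass, and the backward sweep claims those that precede one.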

    # Forward, then backwards
    for is_reversed in range(2):
        last_place = PassPlacement.Cpu
        seq = enumerate(place_vec)
        if is_reversed:
            seq = reversed(list(seq))
        for idx, place in seq:
            if place == PassPlacement.MemoryOnly:
                if last_place == PassPlacement.Npu:
                    place = PassPlacement.Npu
                    place_vec[idx] = place

            if place != PassPlacement.MemoryOnly:
                last_place = place

    # Anything left, assign to the CPU.
    place_vec[place_vec == PassPlacement.MemoryOnly] = PassPlacement.Cpu

    if np.all(place_vec == PassPlacement.Cpu):
        return []  # Nothing to do

    # Create the subgraphs and split passes between them.

    new_subgraphs = []
    split_count = 0
    subgraph_for_pass = {}
    orig_sg.passes = []
    call_pass = {}
    startup_init_passes = {}

    last_place = PassPlacement.Cpu
    curr_sg = orig_sg

    for idx, place in enumerate(place_vec):
        if place != last_place:
            if place == PassPlacement.Npu:
                split_count += 1
                curr_sg = Subgraph("%s_split_%d" % (orig_sg.name, split_count), PassPlacement.Npu)
                new_subgraphs.append(curr_sg)
                call_ps = make_npu_call_op_pass(curr_sg)
                subgraph_for_pass[call_ps] = orig_sg
                orig_sg.passes.append(call_ps)
                call_pass[curr_sg] = call_ps

                startup_init_ps = Pass(
                    curr_sg.name + "_startup_init", PassPlacement.StartupInit, False, NpuBlockType.Default
                )
                curr_sg.passes.append(startup_init_ps)
                startup_init_passes[curr_sg] = startup_init_ps
                subgraph_for_pass[startup_init_ps] = curr_sg

            else:
                curr_sg = orig_sg
            last_place = place
        ps = passes[idx]
        subgraph_for_pass[ps] = curr_sg
        curr_sg.passes.append(ps)

    # Rewrite tensors to fix up the graphs.

    for curr_sg in new_subgraphs:
        for ps in curr_sg.passes:
            for tens in ps.inputs:
                source_sgs = [subgraph_for_pass[op.scheduled_pass] for op in tens.ops]
                assert len(source_sgs) > 0  # Every input tensor must have at least one producer.
                producer_sg = source_sgs[0]
                for sg in source_sgs:
                    assert sg == producer_sg  # All producers need to be in the same subgraph.

                if producer_sg != curr_sg:
                    assert (
                        producer_sg == orig_sg
                    )  # Because we go in order, all the producers must be in the original graph.
                    rewrite_tensor_cpu_producer_npu_consumers(
                        tens, call_pass[curr_sg], startup_init_passes[curr_sg], curr_sg, orig_sg, subgraph_for_pass
                    )

            for tens in ps.outputs:

                dest_sgs = [subgraph_for_pass[op.scheduled_pass] for op in tens.consumers() if op is not None]
                need_rewrite = False
                for sg in dest_sgs:
                    if sg != curr_sg:
                        need_rewrite = True
                        break
                if tens in orig_sg.output_tensors:
                    need_rewrite = True

                if need_rewrite:
                    rewrite_tensor_npu_producer_cpu_consumers(
                        tens, call_pass[curr_sg], startup_init_passes[curr_sg], curr_sg, orig_sg, subgraph_for_pass
                    )

    return new_subgraphs


def extract_npu_subgraphs(nng, arch):
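    # Pulls the NPU passes out of every CPU subgraph in nng, appends the resulting NPU
    # subgraphs, then refreshes the graph and rebuilds the pass links.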

    nng.refresh_after_modification()

    for sg in list(nng.subgraphs):
        if sg.placement == PassPlacement.Cpu:
            new_subgraphs = extract_subgraph(nng, sg, arch)
            nng.subgraphs += new_subgraphs

    nng.refresh_after_modification()
    nng.prune_startup_init_pass()

    for sg in nng.subgraphs:
        sg.build_pass_links()