# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# Mark purpose and select formats for tensors. Also compress the weights.

from . import rewrite_graph
from . import weight_compressor
from .architecture_features import Block
from .nn_graph import TensorPurpose, TensorFormat, PassPlacement
from .operation import NpuBlockType
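

# Helpers for the tensor_purposes table below. Each maps to a callback with
# signature (op, idx) -> TensorPurpose, giving the purpose of the op's idx'th
# input tensor. For example, purpose_from_list([TensorPurpose.FeatureMap,
# TensorPurpose.Weights]) marks input 0 as a feature map and input 1 as weights.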
def purpose_from_list(lst):
    def purpose(op, idx):
        return lst[idx]

    return purpose


def all_fm(op, idx):
    return TensorPurpose.FeatureMap


def all_parameter(op, idx):
    # parameter tensors are currently marked the same way as feature maps
    return TensorPurpose.FeatureMap


def input0_from_output_rest_parameter(op, idx):
    if idx == 0:
        res = op.outputs[0].purpose
        if res == TensorPurpose.Unknown:
            print("Warning: Propagating unknown tensor purpose", op)
        return res
    return TensorPurpose.FeatureMap


def inputs_from_output(op, idx):
    res = op.outputs[0].purpose
    if res == TensorPurpose.Unknown:
        print("Warning: Propagating unknown tensor purpose", op)
    return res
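

# Table of (op types, input purpose callback) pairs, matched in order. The
# final (None, all_fm) entry is the catch-all for op types not listed above.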
tensor_purposes = [  # ops, input_purpose
    (
        set(
            (
                "Relu",
                "Relu6",
                "Mul",
                "Add",
                "Sub",
                "Rsqrt",
                "Abs",
                "Cast",
                "Exp",
                "Floor",
                "FloorDiv",
                "FloorMod",
                "SquaredDifference",
                "AddN",
                "BiasAdd",
                "RealDiv",
                "Maximum",
                "Minimum",
                "Sigmoid",
                "Tanh",
                "FusedBatchNorm",
                "AvgPool",
                "MaxPool",
                "Squeeze",
                "Softmax",
                "LRN",
                "Assign",
                "BatchMatMul",
                "ZerosLike",
                "ExtractImagePatches",
                "MulAct",
                "AddAct",
                "SubAct",
                "DivAct",
                "AvgPoolAct",
                "MaxPoolAct",
                "LeakyRelu",
            )
        ),
        all_fm,
    ),
    (
        set(
            (
                "Conv2D",
                "DepthwiseConv2dNative",
                "MatMul",
                "Conv2DBiasAct",
                "DepthwiseConv2dBiasAct",
                "FullyConnectedAct",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        set(("Conv2DBackpropInputSwitched",)),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        set(("QuantizedConv2D", "QuantizedMatMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Reshape",
                "Min",
                "Max",
                "Mean",
                "Pad",
                "MirrorPad",
                "ArgMax",
                "ArgMin",
                "ExpandDims",
                "ResizeNearestNeighbor",
                "ResizeBilinear",
                "Tile",
                "Transpose",
                "Mfcc",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("QuantizedReshape", "QuantizedResizeBilinear")),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set(("QuantizedBiasAdd", "QuantizedAdd", "QuantizedMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Dequantize",
                "Quantize",
                "QuantizeV2",
                "QuantizedRelu",
                "QuantizedRelu1",
                "QuantizedRelu6",
                "QuantizedAvgPool",
                "QuantizedMaxPool",
                "Slice",
                "SplitV",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BatchToSpaceND", "SpaceToBatchND", "DepthToSpaceND", "SpaceToDepthND")),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BlockLSTM",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("SplitSliceRead",)), purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
    (set(("Shape", "ConcatSliceWrite", "AudioSpectrogram")), purpose_from_list([TensorPurpose.FeatureMap])),
    (
        set(("StridedSlice",)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (set(("Fill", "Pack", "Range")), all_parameter),
    (
        set(("Requantize",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("Placeholder", "SubgraphInput", "Const", "VariableV2")), purpose_from_list([])),
    (set(("FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars")), input0_from_output_rest_parameter),
    (
        set(("Square", "Sqrt", "Log", "Less", "Enter", "Exit", "Identity", "StopGradient", "Merge", "Switch")),
        inputs_from_output,
    ),
    (None, all_fm),
]
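

# Sanity check: each entry must be a set of op-type strings. A bare string,
# e.g. set(("Foo")) missing its trailing comma, would decompose into single
# characters.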
for ops, input_purpose in tensor_purposes:
    if ops is None:
        continue
    for op in ops:
        assert len(op) > 1, "string literal has been decomposed"
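

# Assign a purpose, and with it a memory area, to every tensor in each
# subgraph, driven by the tensor_purposes table above.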
def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
    def mark_tensor_helper(tens, purpose):
        if tens.purpose == TensorPurpose.Unknown or tens.purpose == purpose:
            tens.purpose = purpose
        else:
            assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
        tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]

        if len(tens.ops) == 1 and tens.ops[0].type == "Const":
            # special case constants, as they must be in permanent storage
            tens.mem_area = arch.permanent_storage_mem_area

    def rewrite_mark_tensor_purpose(op, arch):
        # find disconnected outputs and mark them as feature maps
        for tens in op.outputs:
            if not tens.consumers():
                mark_tensor_helper(tens, TensorPurpose.FeatureMap)

        for ops, input_purpose in tensor_purposes:
            if ops is None or op.type in ops:
                if ops is None:
                    print(
                        "warning: don't know how to mark up purpose for",
                        op.type,
                        op.inputs,
                        "triggering all feature map fallback",
                    )
                for idx, tens in enumerate(op.inputs):
                    purpose = input_purpose(op, idx)
                    mark_tensor_helper(tens, purpose)
                break
        return op

    for sg in nng.subgraphs:
        sg = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [rewrite_mark_tensor_purpose])
        for tens in sg.output_tensors:
            mark_tensor_helper(tens, TensorPurpose.FeatureMap)

    if verbose_tensor_purpose:
        nng.print_graph_with_tensors()

    return nng
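

# Op types treated as pure reshapes. Not referenced elsewhere in this file;
# presumably consumed by other passes.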
reshape_operations = set(
    (
        "Reshape",
        "QuantizedReshape",
        "ExpandDims",
        "Squeeze",
        "BatchToSpaceND",
        "SpaceToBatchND",
        "DepthToSpaceND",
        "SpaceToDepthND",
        "Placeholder",
    )
)
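

# Pick a storage format for every tensor from its purpose, then compress the
# weights of any tensor that ends up in the WeightsCompressed format.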
def mark_tensor_format(nng, arch, verbose_tensor_format=False):
    formats_for_tensor = {}

    def init_tens(tens):
        if tens.purpose == TensorPurpose.FeatureMap:
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Weights:
            fmt = arch.default_weight_format
        else:
            assert 0, "unknown tensor purpose %s" % (tens.purpose,)
        return fmt

    def find_npu_usage_of_tensor(tens):
        # follow DMA ops through to the real consumer to find the NPU block type
        for op in tens.consumers():
            if op.type == "DMA":
                return find_npu_usage_of_tensor(op.outputs[0])
            if "npu_block_type" in op.attrs:
                return op.attrs["npu_block_type"]
        return NpuBlockType.Default

    def visit_tens(tens, ps):
        if tens not in formats_for_tensor:
            fmt = init_tens(tens)
        else:
            fmt = formats_for_tensor[tens]

        formats_for_tensor[tens] = fmt

    for sg in nng.subgraphs:
        for ps in sg.passes:
            for tens in ps.outputs:
                visit_tens(tens, ps)
            for tens in ps.intermediates:
                visit_tens(tens, ps)
            for tens in ps.inputs:
                visit_tens(tens, ps)

    for tens, fmt in formats_for_tensor.items():
        tens.set_format(fmt, arch)
        if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
            npu_block_type = find_npu_usage_of_tensor(tens)
            if len(tens.ops) == 1 and tens.ops[0].type == "DMA":
                weight_compressor.compress_weights(tens, arch, npu_block_type, Block(32, 32, 32), 32)
                # Alias compressed weights back into source tensor
                src_tens = tens.ops[0].inputs[0]
                src_tens.compressed_values = tens.compressed_values
                src_tens.storage_shape = tens.storage_shape
                src_tens.brick_size = tens.brick_size
                src_tens.weight_compression_scales = tens.weight_compression_scales
                src_tens.weight_compressed_offsets = tens.weight_compressed_offsets
                src_tens.compression_scale_for_worst_weight_stream = tens.compression_scale_for_worst_weight_stream
                src_tens.storage_compression_scale = tens.storage_compression_scale

    if verbose_tensor_format:
        nng.print_passes_with_tensors()