# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# Mark purpose and select formats for Tensors. Also compresses the weights.

from . import rewrite_graph
from . import weight_compressor
from .architecture_features import Block
from .tensor import TensorPurpose, TensorFormat
from .operation import NpuBlockType


def purpose_from_list(lst):
    def purpose(op, idx):
        return lst[idx]

    return purpose
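
# Usage sketch (illustrative only): purpose_from_list builds a lookup function
# keyed on input index, so that e.g.
#     fn = purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights])
#     fn(op, 1)  # -> TensorPurpose.Weights for the op's second input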


def all_fm(op, idx):
    return TensorPurpose.FeatureMap


def all_parameter(op, idx):
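    # Intentionally identical to all_fm: "parameter" inputs are given the
    # FeatureMap purpose, as this pass does not use a separate parameter purpose.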
    return TensorPurpose.FeatureMap


def input0_from_output_rest_parameter(op, idx):
    if idx == 0:
        res = op.outputs[0].purpose
        if res == TensorPurpose.Unknown:
            print("Warning: Propagating unknown tensor purpose", op)
        return res
    return TensorPurpose.FeatureMap


def inputs_from_output(op, idx):
    res = op.outputs[0].purpose
    if res == TensorPurpose.Unknown:
        print("Warning: Propagating unknown tensor purpose", op)
    return res
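
# Both helpers above propagate the purpose backwards from an op's output to its
# inputs, so pass-through ops (e.g. Identity, see the table below) hand their
# already-marked output purpose on to whatever feeds them.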


tensor_purposes = [  # ops, input_purpose
    (
        set(
            (
                "Relu",
                "Relu6",
                "Mul",
                "Add",
                "Sub",
                "Rsqrt",
                "Abs",
                "Cast",
                "Exp",
                "Floor",
                "FloorDiv",
                "FloorMod",
                "SquaredDifference",
                "AddN",
                "BiasAdd",
                "RealDiv",
                "Maximum",
                "Minimum",
                "Sigmoid",
                "Tanh",
                "FusedBatchNorm",
                "AvgPool",
                "MaxPool",
                "Squeeze",
                "Softmax",
                "LRN",
                "Assign",
                "BatchMatMul",
                "ZerosLike",
                "ExtractImagePatches",
                "MulAct",
                "AddAct",
                "SubAct",
                "DivAct",
                "AvgPoolAct",
                "MaxPoolAct",
                "LeakyRelu",
            )
        ),
        all_fm,
    ),
    (
        set(
            (
                "Conv2D",
                "DepthwiseConv2dNative",
                "MatMul",
                "Conv2DBiasAct",
                "DepthwiseConv2dBiasAct",
                "FullyConnectedAct",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        set(("Conv2DBackpropInputSwitched",)),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        set(("QuantizedConv2D", "QuantizedMatMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Reshape",
                "Min",
                "Max",
                "Mean",
                "Pad",
                "MirrorPad",
                "ArgMax",
                "ArgMin",
                "ExpandDims",
                "ResizeNearestNeighbor",
                "ResizeBilinear",
                "Tile",
                "Transpose",
                "Mfcc",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("QuantizedReshape", "QuantizedResizeBilinear")),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set(("QuantizedBiasAdd", "QuantizedAdd", "QuantizedMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Dequantize",
                "Quantize",
                "QuantizeV2",
                "QuantizedRelu",
                "QuantizedRelu1",
                "QuantizedRelu6",
                "QuantizedAvgPool",
                "QuantizedMaxPool",
                "Slice",
                "SplitV",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BatchToSpaceND", "SpaceToBatchND", "DepthToSpaceND", "SpaceToDepthND")),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BlockLSTM",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("SplitSliceRead",)), purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
    (set(("Shape", "ConcatSliceWrite", "AudioSpectrogram")), purpose_from_list([TensorPurpose.FeatureMap])),
    (
        set(("StridedSlice",)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (set(("Fill", "Pack", "Range")), all_parameter),
    (
        set(("Requantize",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("Placeholder", "SubgraphInput", "Const", "VariableV2")), purpose_from_list([])),
    (set(("FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars")), input0_from_output_rest_parameter),
    (
        set(("Square", "Sqrt", "Log", "Less", "Enter", "Exit", "Identity", "StopGradient", "Merge", "Switch")),
        inputs_from_output,
    ),
    (None, all_fm),
]
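

# Import-time sanity check for the table above: a single-element tuple written
# without its trailing comma, e.g. set(("Conv2D")), decomposes the string into
# single characters, so every operator name must be longer than one character.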
for ops, input_purpose in tensor_purposes:
    if ops is None:
        continue
    for op in ops:
        assert len(op) > 1, "string literal has been decomposed"


def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
    def mark_tensor_helper(tens, purpose):
        if tens.purpose == TensorPurpose.Unknown or tens.purpose == purpose:
            tens.purpose = purpose
        else:
            assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
        tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]

        if len(tens.ops) == 1 and tens.ops[0].type == "Const":
            # Special case constants, as they must be in permanent storage.
            tens.mem_area = arch.permanent_storage_mem_area

    def rewrite_mark_tensor_purpose(op, arch):
        # find disconnected outputs and mark as parameters
        for tens in op.outputs:
            if not tens.consumers():
                mark_tensor_helper(tens, TensorPurpose.FeatureMap)

        for ops, input_purpose in tensor_purposes:
            if ops is None or op.type in ops:
                if ops is None:
                    print(
                        "warning: don't know how to mark up purpose for",
                        op.type,
                        op.inputs,
                        "triggering all feature map fallback",
                    )
                for idx, tens in enumerate(op.inputs):
                    purpose = input_purpose(op, idx)
                    mark_tensor_helper(tens, purpose)
                break
        return op

    for sg in nng.subgraphs:
        sg = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [rewrite_mark_tensor_purpose])
        for tens in sg.output_tensors:
            mark_tensor_helper(tens, TensorPurpose.FeatureMap)

    if verbose_tensor_purpose:
        nng.print_graph_with_tensors()

    return nng


reshape_operations = set(
    (
        "Reshape",
        "QuantizedReshape",
        "ExpandDims",
        "Squeeze",
        "BatchToSpaceND",
        "SpaceToBatchND",
        "DepthToSpaceND",
        "SpaceToDepthND",
        "Placeholder",
    )
)


def mark_tensor_format(nng, arch, verbose_tensor_format=False):
    formats_for_tensor = {}

    def init_tens(tens):
        if tens.purpose == TensorPurpose.FeatureMap:
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Weights:
            fmt = arch.default_weight_format
        else:
            assert 0, "unknown tensor purpose %s" % (tens.purpose,)
        return fmt

    def find_npu_usage_of_tensor(tens):
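        # Follow DMA ops to the tensor's real consumer; the relevant block type
        # is that of the NPU op being fed, not of the DMA itself.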
        for op in tens.consumers():
            if op.type == "DMA":
                return find_npu_usage_of_tensor(op.outputs[0])
            if "npu_block_type" in op.attrs:
                return op.attrs["npu_block_type"]
        return NpuBlockType.Default

    def visit_tens(tens, ps):
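        # First visit wins: once a format has been recorded for a tensor it is
        # reused for every later pass that touches the same tensor.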
        if tens not in formats_for_tensor:
            fmt = init_tens(tens)
        else:
            fmt = formats_for_tensor[tens]

        formats_for_tensor[tens] = fmt

    for sg in nng.subgraphs:
        for ps in sg.passes:
            for tens in ps.outputs:
                visit_tens(tens, ps)
            for tens in ps.intermediates:
                visit_tens(tens, ps)
            for tens in ps.inputs:
                visit_tens(tens, ps)

    for tens, fmt in formats_for_tensor.items():
        tens.set_format(fmt, arch)
        if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
            npu_block_type = find_npu_usage_of_tensor(tens)
            if len(tens.ops) == 1 and tens.ops[0].type == "DMA":
                weight_compressor.compress_weights(tens, arch, npu_block_type, Block(32, 32, 32), 32)
                # Alias compressed weights back into source tensor
                src_tens = tens.ops[0].inputs[0]
                src_tens.compressed_values = tens.compressed_values
                src_tens.storage_shape = tens.storage_shape
                src_tens.brick_size = tens.brick_size
                src_tens.weight_compression_scales = tens.weight_compression_scales
                src_tens.weight_compressed_offsets = tens.weight_compressed_offsets
                src_tens.compression_scale_for_worst_weight_stream = tens.compression_scale_for_worst_weight_stream
                src_tens.storage_compression_scale = tens.storage_compression_scale

    if verbose_tensor_format:
        nng.print_passes_with_tensors()
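

# Minimal sketch of the intended call order (nng and arch as used above; the
# driver that owns them lives outside this module). Purposes must be marked
# first, since init_tens asserts on tensors with an unknown purpose:
#   nng = mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False)
#   mark_tensor_format(nng, arch, verbose_tensor_format=False)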