# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Mark purpose and select formats for Tensors. Also compresses the weights.
from . import rewrite_graph
from . import weight_compressor
from .operation import NpuBlockType
from .tensor import TensorFormat
from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010023
24
def purpose_from_list(lst):
    """Build an input_purpose callback: input index ``idx`` is assigned ``lst[idx]``."""
    return lambda op, idx: lst[idx]
30
31
def all_fm(op, idx):
    # Input-purpose callback: every input of *op* is marked as a FeatureMap.
    return TensorPurpose.FeatureMap
34
35
def all_parameter(op, idx):
    # NOTE(review): despite the name, this currently marks every input as a
    # FeatureMap — identical to all_fm.  Confirm whether a distinct
    # "parameter" purpose was intended for Fill/Pack/Range inputs.
    return TensorPurpose.FeatureMap
38
39
def input0_from_output_rest_parameter(op, idx):
    """Input 0 inherits the purpose of *op*'s first output; every other
    input is treated as a feature map."""
    if idx != 0:
        return TensorPurpose.FeatureMap
    purpose = op.outputs[0].purpose
    if purpose == TensorPurpose.Unknown:
        print("Warning: Propagating unknown tensor purpose", op)
    return purpose
47
48
def inputs_from_output(op, idx):
    """Every input of *op* inherits the purpose of its first output."""
    purpose = op.outputs[0].purpose
    if purpose == TensorPurpose.Unknown:
        print("Warning: Propagating unknown tensor purpose", op)
    return purpose
54
Diego Russoea6111a2020-04-14 18:41:58 +010055
# Dispatch table mapping op types to an input_purpose callback.  Each entry
# is an (ops, input_purpose) pair: *ops* is a set of op type names (None is
# the catch-all fallback and must stay last), and *input_purpose* is called
# as input_purpose(op, idx) to yield the TensorPurpose of input *idx*.
#
# Set literals are used instead of set((...)) so a missing comma cannot
# silently decompose a lone op name into single characters (the failure mode
# the sanity-check loop below exists to catch).
tensor_purposes = [  # ops, input_purpose
    (
        {
            "Relu",
            "Relu6",
            "Mul",
            "Add",
            "Sub",
            "Rsqrt",
            "Abs",
            "Cast",
            "Exp",
            "Floor",
            "FloorDiv",
            "FloorMod",
            "SquaredDifference",
            "AddN",
            "BiasAdd",
            "RealDiv",
            "Maximum",
            "Minimum",
            "Sigmoid",
            "Tanh",
            "FusedBatchNorm",
            "AvgPool",
            "MaxPool",
            "Squeeze",
            "Softmax",
            "LRN",
            "Assign",
            "BatchMatMul",
            "ZerosLike",
            "ExtractImagePatches",
            "MulAct",
            "AddAct",
            "SubAct",
            "DivAct",
            "AvgPoolAct",
            "MaxPoolAct",
            "LeakyRelu",
        },
        all_fm,
    ),
    (
        {
            "Conv2D",
            "DepthwiseConv2dNative",
            "MatMul",
            "Conv2DBiasAct",
            "DepthwiseConv2dBiasAct",
            "FullyConnectedAct",
        },
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        {"Conv2DBackpropInputSwitched"},
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        {"QuantizedConv2D", "QuantizedMatMul"},
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        {
            "Reshape",
            "Min",
            "Max",
            "Mean",
            "Pad",
            "MirrorPad",
            "ArgMax",
            "ArgMin",
            "ExpandDims",
            "ResizeNearestNeighbor",
            "ResizeBilinear",
            "Tile",
            "Transpose",
            "Mfcc",
        },
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        {"QuantizedReshape", "QuantizedResizeBilinear"},
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        {"QuantizedBiasAdd", "QuantizedAdd", "QuantizedMul"},
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        {
            "Dequantize",
            "Quantize",
            "QuantizeV2",
            "QuantizedRelu",
            "QuantizedRelu1",
            "QuantizedRelu6",
            "QuantizedAvgPool",
            "QuantizedMaxPool",
            "Slice",
            "SplitV",
        },
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        {"BatchToSpaceND", "SpaceToBatchND", "DepthToSpaceND", "SpaceToDepthND"},
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        {"BlockLSTM"},
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    ({"SplitSliceRead"}, purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
    ({"Shape", "ConcatSliceWrite", "AudioSpectrogram"}, purpose_from_list([TensorPurpose.FeatureMap])),
    (
        {"StridedSlice"},
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    ({"Fill", "Pack", "Range"}, all_parameter),
    (
        {"Requantize"},
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    # Source ops: no inputs to classify.
    ({"Placeholder", "SubgraphInput", "Const", "VariableV2"}, purpose_from_list([])),
    ({"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars"}, input0_from_output_rest_parameter),
    (
        {"Square", "Sqrt", "Log", "Less", "Enter", "Exit", "Identity", "StopGradient", "Merge", "Switch"},
        inputs_from_output,
    ),
    # Fallback: unknown op types warn and treat all inputs as feature maps.
    (None, all_fm),
]
237
238
# Sanity check the table: a set of op names must never contain a
# single-character string, which would indicate a decomposed string literal
# (e.g. set(("Conv2D")) without the trailing comma).
for ops, input_purpose in tensor_purposes:
    if ops is not None:
        for op in ops:
            assert len(op) > 1, "string literal has been decomposed"
244
245
def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
    """Assign a TensorPurpose and memory area to every tensor in *nng*.

    Each subgraph is walked in pre-order; op inputs are classified via the
    tensor_purposes table and every tensor is placed in the memory area the
    architecture maps to its purpose.  Returns the (mutated) graph.
    """

    def set_purpose(tens, purpose):
        # A tensor can be visited via several ops; a conflicting purpose is
        # a hard error rather than a silent overwrite.
        if tens.purpose in (TensorPurpose.Unknown, purpose):
            tens.purpose = purpose
        else:
            assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
        tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]
        if len(tens.ops) == 1 and tens.ops[0].type == "Const":
            # special case constants, as they must be in permanent storage
            tens.mem_area = arch.permanent_storage_mem_area

    def visit_op(op, arch):
        # find disconnected outputs and mark as parameters
        for tens in op.outputs:
            if not tens.consumers():
                set_purpose(tens, TensorPurpose.FeatureMap)

        for ops, input_purpose in tensor_purposes:
            if ops is not None and op.type not in ops:
                continue
            if ops is None:
                print(
                    "warning: don't know how to mark up purpose for",
                    op.type,
                    op.inputs,
                    "triggering all feature map fallback",
                )
            for idx, tens in enumerate(op.inputs):
                set_purpose(tens, input_purpose(op, idx))
            if op.type == "Reshape":
                # Reshape's input and output point to same data
                op.outputs[0].mem_area = op.inputs[0].mem_area
            break
        return op

    for sg in nng.subgraphs:
        sg = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [visit_op])
        for tens in sg.output_tensors:
            set_purpose(tens, TensorPurpose.FeatureMap)

    if verbose_tensor_purpose:
        nng.print_graph_with_tensors()

    return nng
293
294
# Op types whose output is a reshaped view of their input data.
reshape_operations = {
    "Reshape",
    "QuantizedReshape",
    "ExpandDims",
    "Squeeze",
    "BatchToSpaceND",
    "SpaceToBatchND",
    "DepthToSpaceND",
    "SpaceToDepthND",
    "Placeholder",
}
308
309
def mark_tensor_format(nng, arch, verbose_tensor_format=False):
    """Choose a TensorFormat for every tensor in *nng* and compress weights.

    Each tensor gets the architecture's default format for its purpose;
    tensors that end up WeightsCompressed (and carry values) are run through
    the weight compressor and aliased back to their DMA source tensor.
    """
    # Maps each visited tensor to the format it will be assigned below.
    formats_for_tensor = {}

    def init_tens(tens):
        # Default format is decided purely by the tensor's purpose.
        if tens.purpose == TensorPurpose.FeatureMap:
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Weights:
            fmt = arch.default_weight_format
        else:
            assert 0, "unknown tensor purpose %s" % (tens.purpose,)
        return fmt

    def find_npu_usage_of_tensor(tens):
        # Follow DMA ops through to the real consumer to discover which NPU
        # block type will read this tensor; Default if none is found.
        for op in tens.consumers():
            if op.type == "DMA":
                return find_npu_usage_of_tensor(op.outputs[0])
            if "npu_block_type" in op.attrs:
                return op.attrs["npu_block_type"]
        return NpuBlockType.Default

    def visit_tens(tens, ps):
        # NOTE(review): *ps* is currently unused; kept to match the call
        # sites below.
        if tens not in formats_for_tensor:
            fmt = init_tens(tens)
        else:
            fmt = formats_for_tensor[tens]

        formats_for_tensor[tens] = fmt

    for sg in nng.subgraphs:
        for ps in sg.passes:
            for tens in ps.outputs:
                visit_tens(tens, ps)
            for tens in ps.intermediates:
                visit_tens(tens, ps)
            for tens in ps.inputs:
                visit_tens(tens, ps)

    for tens, fmt in formats_for_tensor.items():
        tens.set_format(fmt, arch)
        if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
            src_tens = tens.get_dma_src_tensor()
            if src_tens is not None:
                npu_block_type = find_npu_usage_of_tensor(tens)
                # assumes 32, 32 are block-traversal/depth parameters for the
                # compressor — TODO confirm against weight_compressor.compress_weights
                weight_compressor.compress_weights(arch, nng, tens, npu_block_type, 32, 32)
                # Alias compressed weights back into source tensor
                src_tens.copy_compressed_weight_info(tens)

    if verbose_tensor_format:
        nng.print_passes_with_tensors()