# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Mark purpose and select formats for Tensors. Also compresses the weights.
from . import rewrite_graph
from . import weight_compressor
from .tensor import TensorFormat
from .tensor import TensorPurpose


def purpose_from_list(lst):
    def purpose(op, idx):
        return lst[idx]

    return purpose
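# Usage sketch (hypothetical conv-style op with inputs ifm/weights/bias):
#   conv_purpose = purpose_from_list(
#       [TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]
#   )
#   conv_purpose(op, 1)  # -> TensorPurpose.Weights for the op's second input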


def all_fm(op, idx):
    return TensorPurpose.FeatureMap


def all_parameter(op, idx):
    return TensorPurpose.FeatureMap
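# Note: all_parameter is currently identical to all_fm; both map every input
# to TensorPurpose.FeatureMap. The separate name presumably reserves a hook
# for parameter-style ops ("Fill", "Pack", "Range" below) to diverge later.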


def input0_from_output_rest_parameter(op, idx):
    if idx == 0:
        res = op.outputs[0].purpose
        if res == TensorPurpose.Unknown:
            print("Warning: Propagating unknown tensor purpose", op)
        return res
    return TensorPurpose.FeatureMap


def inputs_from_output(op, idx):
    res = op.outputs[0].purpose
    if res == TensorPurpose.Unknown:
        print("Warning: Propagating unknown tensor purpose", op)
    return res

tensor_purposes = [  # ops, input_purpose
    (
        set(
            (
                "Relu",
                "Relu6",
                "Mul",
                "Add",
                "Sub",
                "Rsqrt",
                "Abs",
                "Cast",
                "Exp",
                "Floor",
                "FloorDiv",
                "FloorMod",
                "SquaredDifference",
                "AddN",
                "BiasAdd",
                "RealDiv",
                "Maximum",
                "Minimum",
                "Sigmoid",
                "Tanh",
                "FusedBatchNorm",
                "AvgPool",
                "MaxPool",
                "Squeeze",
                "Softmax",
                "LRN",
                "Assign",
                "BatchMatMul",
                "ZerosLike",
                "ExtractImagePatches",
                "MulAct",
                "AddAct",
                "SubAct",
                "DivAct",
                "AvgPoolAct",
                "MaxPoolAct",
                "LeakyRelu",
            )
        ),
        all_fm,
    ),
    (
        set(
            (
                "Conv2D",
                "DepthwiseConv2dNative",
                "MatMul",
                "Conv2DBiasAct",
                "DepthwiseConv2dBiasAct",
                "FullyConnectedAct",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        set(("Conv2DBackpropInputSwitchedBias",)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set(("QuantizedConv2D", "QuantizedMatMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Reshape",
                "Min",
                "Max",
                "Mean",
                "Pad",
                "MirrorPad",
                "ArgMax",
                "ArgMin",
                "ExpandDims",
                "ResizeNearestNeighbor",
                "ResizeBilinear",
                "Tile",
                "Transpose",
                "Mfcc",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("QuantizedReshape", "QuantizedResizeBilinear")),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set(("QuantizedBiasAdd", "QuantizedAdd", "QuantizedMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Dequantize",
                "Quantize",
                "QuantizeV2",
                "QuantizedRelu",
                "QuantizedRelu1",
                "QuantizedRelu6",
                "QuantizedAvgPool",
                "QuantizedMaxPool",
                "Slice",
                "SplitV",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BatchToSpaceND", "SpaceToBatchND", "DepthToSpaceND", "SpaceToDepthND")),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BlockLSTM",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("SplitSliceRead",)), purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
    (set(("Shape", "ConcatSliceWrite", "AudioSpectrogram")), purpose_from_list([TensorPurpose.FeatureMap])),
    (
        set(("StridedSlice",)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (set(("Fill", "Pack", "Range")), all_parameter),
    (
        set(("Requantize",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("Placeholder", "SubgraphInput", "Const", "VariableV2")), purpose_from_list([])),
    (set(("FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars")), input0_from_output_rest_parameter),
    (
        set(("Square", "Sqrt", "Log", "Less", "Enter", "Exit", "Identity", "StopGradient", "Merge", "Switch")),
        inputs_from_output,
    ),
    (None, all_fm),
]
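# Lookup sketch: tensor_purposes is scanned in order and the first matching
# entry wins; the trailing (None, all_fm) pair is the catch-all. E.g. for a
# "Conv2D" op, input index 1 resolves via the second entry:
#   for ops, input_purpose in tensor_purposes:
#       if ops is None or op.type in ops:
#           purpose = input_purpose(op, 1)  # -> TensorPurpose.Weights
#           break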


for ops, input_purpose in tensor_purposes:
    if ops is None:
        continue
    for op in ops:
        assert len(op) > 1, "string literal has been decomposed"


def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
    def mark_tensor_helper(tens, purpose):
        if tens.purpose == TensorPurpose.Unknown or tens.purpose == purpose:
            tens.purpose = purpose
        else:
            assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
        tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]

        if len(tens.ops) == 1 and tens.ops[0].type == "Const":
            tens.mem_area = (
                arch.permanent_storage_mem_area
            )  # special case constants, as they must be in permanent storage

    def rewrite_mark_tensor_purpose(op, arch):
        # find disconnected outputs and mark as parameters
        for tens in op.outputs:
            if not tens.consumers():
                mark_tensor_helper(tens, TensorPurpose.FeatureMap)

        for ops, input_purpose in tensor_purposes:
            if ops is None or op.type in ops:
                if ops is None:
                    print(
                        "warning: don't know how to mark up purpose for",
                        op.type,
                        op.inputs,
                        "triggering all feature map fallback",
                    )
                for idx, tens in enumerate(op.inputs):
                    purpose = input_purpose(op, idx)
                    mark_tensor_helper(tens, purpose)
                if op.type == "Reshape":
                    # Reshape's input and output point to same data
                    op.outputs[0].mem_area = op.inputs[0].mem_area
                break
        return op

    for sg in nng.subgraphs:
        sg = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [rewrite_mark_tensor_purpose])
        for tens in sg.output_tensors:
            mark_tensor_helper(tens, TensorPurpose.FeatureMap)

    if verbose_tensor_purpose:
        nng.print_graph_with_tensors()

    return nng
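# Flow sketch: rewrite_graph_pre_order visits each op once and applies
# rewrite_mark_tensor_purpose, assigning input purposes from the table above;
# subgraph output tensors are then marked as FeatureMap afterwards.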


reshape_operations = set(
    (
        "Reshape",
        "QuantizedReshape",
        "ExpandDims",
        "Squeeze",
        "BatchToSpaceND",
        "SpaceToBatchND",
        "DepthToSpaceND",
        "SpaceToDepthND",
        "Placeholder",
    )
)


def mark_tensor_format(nng, arch, verbose_tensor_format=False):
    formats_for_tensor = {}

    def init_tens(tens):
        if tens.purpose == TensorPurpose.FeatureMap:
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Weights:
            fmt = arch.default_weight_format
        else:
            assert 0, "unknown tensor purpose %s" % (tens.purpose,)
        return fmt

    def visit_tens(tens, ps):
        if tens not in formats_for_tensor:
            fmt = init_tens(tens)
        else:
            fmt = formats_for_tensor[tens]

        formats_for_tensor[tens] = fmt

    for sg in nng.subgraphs:
        for ps in sg.passes:
            for tens in ps.outputs:
                visit_tens(tens, ps)
            for tens in ps.intermediates:
                visit_tens(tens, ps)
            for tens in ps.inputs:
                visit_tens(tens, ps)

    for tens, fmt in formats_for_tensor.items():
        tens.set_format(fmt, arch)
        if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
            src_tens = tens.get_dma_src_tensor()
            if src_tens is not None:
                op = tens.find_npu_op()
                npu_block_type = op.attrs["npu_block_type"]
                weight_compressor.compress_weights(arch, nng, tens, npu_block_type, 32, 32, op.get_dilation_h_w())
                # Alias compressed weights back into source tensor
                src_tens.copy_compressed_weight_info(tens)

    if verbose_tensor_format:
        nng.print_passes_with_tensors()
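# Usage sketch (assumed caller, e.g. the compiler driver): the two passes run
# in this order, since mark_tensor_format keys off the purposes assigned by
# mark_tensor_purpose:
#   nng = mark_tensor_purpose(nng, arch, verbose_tensor_purpose)
#   mark_tensor_format(nng, arch, verbose_tensor_format)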