blob: 72ab8cfa7c4db43634d0d5e901ada5fe0beb866e [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# Mark purpose and select formats for Tensors. Also compresses the weights.
Tim Hall79d07d22020-04-27 18:20:16 +010018from . import rewrite_graph
19from . import weight_compressor
Diego Russoe8a10452020-04-21 17:39:10 +010020from .tensor import TensorFormat
21from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010022
23
def purpose_from_list(lst):
    """Build an input-purpose callback: input index idx gets purpose lst[idx]."""

    def pick(op, idx):
        # Purpose depends only on the input's position, not on the op itself.
        return lst[idx]

    return pick
29
30
def all_fm(op, idx):
    """Mark every input of the op as a feature map."""
    return TensorPurpose.FeatureMap
33
34
def all_parameter(op, idx):
    """Mark every input of the op as a feature map.

    NOTE(review): despite the name, this is currently identical to all_fm
    (no distinct "parameter" purpose is used) — confirm this is intentional.
    """
    return TensorPurpose.FeatureMap
37
38
def input0_from_output_rest_parameter(op, idx):
    """Input 0 inherits the purpose of the op's first output; all other inputs
    are treated as feature maps."""
    if idx != 0:
        return TensorPurpose.FeatureMap
    res = op.outputs[0].purpose
    if res == TensorPurpose.Unknown:
        # The output has not been resolved yet; propagating Unknown is legal
        # but suspicious, so flag it.
        print("Warning: Propagating unknown tensor purpose", op)
    return res
46
47
def inputs_from_output(op, idx):
    """Every input inherits the purpose of the op's first output."""
    inherited = op.outputs[0].purpose
    if inherited == TensorPurpose.Unknown:
        # Flag unresolved purposes so the fallback is visible in the log.
        print("Warning: Propagating unknown tensor purpose", op)
    return inherited
53
Diego Russoea6111a2020-04-14 18:41:58 +010054
# Dispatch table mapping op types to a callback that assigns a TensorPurpose
# to each of the op's inputs. Each entry is (set-of-op-type-names, callback);
# the callback receives (op, input_index) and returns a TensorPurpose.
# The table is scanned in order and the FIRST matching entry wins; the final
# (None, all_fm) entry is the catch-all fallback for unknown op types.
tensor_purposes = [  # ops, input_purpose
    (
        # Elementwise / activation / pooling ops: every input is a feature map.
        set(
            (
                "Relu",
                "Relu6",
                "Mul",
                "Add",
                "Sub",
                "Rsqrt",
                "Abs",
                "Cast",
                "Exp",
                "Floor",
                "FloorDiv",
                "FloorMod",
                "SquaredDifference",
                "AddN",
                "BiasAdd",
                "RealDiv",
                "Maximum",
                "Minimum",
                "Sigmoid",
                "Tanh",
                "FusedBatchNorm",
                "AvgPool",
                "MaxPool",
                "Squeeze",
                "Softmax",
                "LRN",
                "Assign",
                "BatchMatMul",
                "ZerosLike",
                "ExtractImagePatches",
                "MulAct",
                "AddAct",
                "SubAct",
                "DivAct",
                "AvgPoolAct",
                "MaxPoolAct",
                "LeakyRelu",
            )
        ),
        all_fm,
    ),
    (
        # Convolution / matmul ops: input 1 holds the weights, input 2 the bias.
        set(
            (
                "Conv2D",
                "DepthwiseConv2dNative",
                "MatMul",
                "Conv2DBiasAct",
                "DepthwiseConv2dBiasAct",
                "FullyConnectedAct",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        # Transpose convolution with bias: extra trailing feature-map input.
        set(("Conv2DBackpropInputSwitchedBias",)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        # Quantized conv/matmul: weights at input 1, quantization params as FMs.
        set(("QuantizedConv2D", "QuantizedMatMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        # Two-input shape/reduction ops: data plus a shape/axis/size tensor.
        set(
            (
                "Reshape",
                "Min",
                "Max",
                "Mean",
                "Pad",
                "MirrorPad",
                "ArgMax",
                "ArgMin",
                "ExpandDims",
                "ResizeNearestNeighbor",
                "ResizeBilinear",
                "Tile",
                "Transpose",
                "Mfcc",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("QuantizedReshape", "QuantizedResizeBilinear")),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set(("QuantizedBiasAdd", "QuantizedAdd", "QuantizedMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Dequantize",
                "Quantize",
                "QuantizeV2",
                "QuantizedRelu",
                "QuantizedRelu1",
                "QuantizedRelu6",
                "QuantizedAvgPool",
                "QuantizedMaxPool",
                "Slice",
                "SplitV",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BatchToSpaceND", "SpaceToBatchND", "DepthToSpaceND", "SpaceToDepthND")),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        # BlockLSTM: input 4 is the weight matrix; everything else is a FM.
        set(("BlockLSTM",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("SplitSliceRead",)), purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
    (set(("Shape", "ConcatSliceWrite", "AudioSpectrogram")), purpose_from_list([TensorPurpose.FeatureMap])),
    (
        set(("StridedSlice",)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (set(("Fill", "Pack", "Range")), all_parameter),
    (
        set(("Requantize",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    # Source ops take no inputs, hence the empty purpose list.
    (set(("Placeholder", "SubgraphInput", "Const", "VariableV2")), purpose_from_list([])),
    (set(("FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars")), input0_from_output_rest_parameter),
    (
        # Pass-through ops: inputs take whatever purpose the output already has.
        set(("Square", "Sqrt", "Log", "Less", "Enter", "Exit", "Identity", "StopGradient", "Merge", "Switch")),
        inputs_from_output,
    ),
    # Catch-all: unknown op types fall back to all-feature-map (with a warning
    # printed at mark time).
    (None, all_fm),
]
238
239
# Import-time sanity check on the table above: every op entry must be a real
# multi-character name. Guards against writing set("Relu") instead of
# set(("Relu",)), which would silently decompose the string into characters.
for ops, input_purpose in tensor_purposes:
    if ops is not None:
        for op in ops:
            assert len(op) > 1, "string literal has been decomposed"
245
246
def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
    """Assign a TensorPurpose and memory area to every tensor in the graph.

    Walks each subgraph pre-order, looks up every op in the tensor_purposes
    table and applies the matching input-purpose callback. Returns nng with
    tensors mutated in place.
    """

    def mark_tensor_helper(tens, purpose):
        # Set (or re-confirm) the tensor's purpose; conflicting assignments
        # from two different consumers are a hard error.

        if tens.purpose == TensorPurpose.Unknown or tens.purpose == purpose:
            tens.purpose = purpose
        else:
            assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
        # Memory area follows purpose via the architecture's mapping table.
        tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]

        if len(tens.ops) == 1 and tens.ops[0].type == "Const":
            tens.mem_area = (
                arch.permanent_storage_mem_area
            )  # special case constants, as they must be in permanent storage

    def rewrite_mark_tensor_purpose(op, arch):
        # Outputs with no consumers would otherwise never be visited via an
        # input list, so mark them as feature maps here.
        # (The original comment said "parameters", but FeatureMap is what is
        # actually assigned.)
        for tens in op.outputs:
            if not tens.consumers():
                mark_tensor_helper(tens, TensorPurpose.FeatureMap)

        # First matching table entry wins; (None, all_fm) is the fallback.
        for ops, input_purpose in tensor_purposes:
            if ops is None or op.type in ops:
                if ops is None:
                    print(
                        "warning: don't know how to mark up purpose for",
                        op.type,
                        op.inputs,
                        "triggering all feature map fallback",
                    )
                for idx, tens in enumerate(op.inputs):
                    purpose = input_purpose(op, idx)
                    mark_tensor_helper(tens, purpose)
                if op.type == "Reshape":
                    # Reshape's input and output point to same data
                    op.outputs[0].mem_area = op.inputs[0].mem_area
                break
        return op

    for sg in nng.subgraphs:
        sg = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [rewrite_mark_tensor_purpose])
        # Subgraph outputs are always feature maps, even if unconsumed.
        for tens in sg.output_tensors:
            mark_tensor_helper(tens, TensorPurpose.FeatureMap)

    if verbose_tensor_purpose:
        nng.print_graph_with_tensors()

    return nng
294
295
# Op types regarded as pure reshapes (grouped here for use by later passes;
# no use is visible in this chunk).
reshape_operations = {
    "Reshape",
    "QuantizedReshape",
    "ExpandDims",
    "Squeeze",
    "BatchToSpaceND",
    "SpaceToBatchND",
    "DepthToSpaceND",
    "SpaceToDepthND",
    "Placeholder",
}
309
310
def mark_tensor_format(nng, arch, verbose_tensor_format=False):
    """Assign a TensorFormat to every tensor, then compress DMA'd weight tensors.

    Formats are derived from each tensor's previously-marked purpose via the
    architecture defaults. Tensors are mutated in place.
    """
    # Cache of tensor -> chosen format; also guarantees each tensor is
    # formatted exactly once even if it appears in several passes.
    formats_for_tensor = {}

    def init_tens(tens):
        # Map purpose -> architecture default format; Unknown stays Unknown.
        if tens.purpose == TensorPurpose.FeatureMap:
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Weights:
            fmt = arch.default_weight_format
        elif tens.purpose == TensorPurpose.Unknown:
            fmt = TensorFormat.Unknown
        else:
            assert 0, "unknown tensor purpose %s" % (tens.purpose,)
        return fmt

    def visit_tens(tens, ps):
        # First visit decides the format; later visits keep the cached value.
        if tens not in formats_for_tensor:
            fmt = init_tens(tens)
        else:
            fmt = formats_for_tensor[tens]

        formats_for_tensor[tens] = fmt

    for sg in nng.subgraphs:
        for ps in sg.passes:
            for tens in ps.outputs:
                visit_tens(tens, ps)
            for tens in ps.intermediates:
                visit_tens(tens, ps)
            for tens in ps.inputs:
                visit_tens(tens, ps)

    for tens, fmt in formats_for_tensor.items():
        tens.set_format(fmt, arch)
        if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
            # Only compress tensors that are the destination of a DMA; the
            # source tensor then aliases the compressed data.
            src_tens = tens.get_dma_src_tensor()
            if src_tens is not None:
                op = tens.find_npu_op()
                npu_block_type = op.attrs["npu_block_type"]
                # NOTE(review): 32/32 appear to be block traversal parameters
                # passed to the compressor — confirm against compress_weights.
                weight_compressor.compress_weights(arch, nng, tens, npu_block_type, 32, 32, op.get_dilation_h_w())
                # Alias compressed weights back into source tensor
                src_tens.copy_compressed_weight_info(tens)

    if verbose_tensor_format:
        nng.print_passes_with_tensors()