# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Mark purpose and select formats for Tensors. Also compresses the weights.
from . import rewrite_graph
from . import weight_compressor
from .errors import OperatorError
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tflite_mapping import custom_prefix


def purpose_from_list(lst):
    def purpose(op, idx):
        return lst[idx]

    return purpose
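

# Usage sketch for purpose_from_list (illustration only; "op" stands for any
# operation object): the returned callback maps an input index to the purpose
# at that index, so the Conv2D entry in the table below resolves index 1, the
# weight tensor, to TensorPurpose.Weights:
#
#     conv2d_purpose = purpose_from_list(
#         [TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]
#     )
#     conv2d_purpose(op, 1)  # -> TensorPurpose.Weights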


def all_fm(op, idx):
    return TensorPurpose.FeatureMap


def all_parameter(op, idx):
    # There is no separate parameter purpose; parameters are stored as feature
    # maps, so this is currently identical to all_fm.
    return TensorPurpose.FeatureMap


def input0_from_output_rest_parameter(op, idx):
    if idx == 0:
        res = op.outputs[0].purpose
        if res == TensorPurpose.Unknown:
            print("Warning: Propagating unknown tensor purpose", op)
        return res
    return TensorPurpose.FeatureMap


def inputs_from_output(op, idx):
    res = op.outputs[0].purpose
    if res == TensorPurpose.Unknown:
        print("Warning: Propagating unknown tensor purpose", op)
    return res
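

# Both input0_from_output_rest_parameter and inputs_from_output propagate
# purpose backwards, from an op's first output to its inputs. For example, the
# table below routes "Identity" through inputs_from_output, so an Identity
# op's inputs inherit whatever purpose its output has already been given.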


tensor_purposes = [  # ops, input_purpose
    (
        set(
            (
                "Relu",
                "Relu6",
                "Mul",
                "Add",
                "Sub",
                "Rsqrt",
                "Abs",
                "Cast",
                "Exp",
                "Floor",
                "FloorDiv",
                "FloorMod",
                "SquaredDifference",
                "AddN",
                "BiasAdd",
                "RealDiv",
                "Maximum",
                "Minimum",
                "Sigmoid",
                "Tanh",
                "FusedBatchNorm",
                "AvgPool",
                "MaxPool",
                "Squeeze",
                "Softmax",
                "LRN",
                "Assign",
                "BatchMatMul",
                "ZerosLike",
                "ExtractImagePatches",
                "MulAct",
                "AddAct",
                "SubAct",
                "DivAct",
                "AvgPoolAct",
                "MaxPoolAct",
                "LeakyRelu",
            )
        ),
        all_fm,
    ),
    (
        set(
            (
                "Conv2D",
                "DepthwiseConv2dNative",
                "MatMul",
                "Conv2DBiasAct",
                "DepthwiseConv2dBiasAct",
                "FullyConnectedAct",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        set(("Conv2DBackpropInputSwitchedBias",)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set(("QuantizedConv2D", "QuantizedMatMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Reshape",
                "Min",
                "Max",
                "Mean",
                "Pad",
                "MirrorPad",
                "ArgMax",
                "ArgMin",
                "ExpandDims",
                "ResizeNearestNeighbor",
                "ResizeBilinear",
                "Tile",
                "Transpose",
                "Mfcc",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("QuantizedReshape", "QuantizedResizeBilinear")),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set(("QuantizedBiasAdd", "QuantizedAdd", "QuantizedMul")),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                "Dequantize",
                "Quantize",
                "QuantizeV2",
                "QuantizedRelu",
                "QuantizedRelu1",
                "QuantizedRelu6",
                "QuantizedAvgPool",
                "QuantizedMaxPool",
                "Slice",
                "SplitV",
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BatchToSpaceND", "SpaceToBatchND", "DepthToSpaceND", "SpaceToDepthND")),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set(("BlockLSTM",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("SplitSliceRead",)), purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
    (set(("Shape", "ConcatSliceWrite", "AudioSpectrogram")), purpose_from_list([TensorPurpose.FeatureMap])),
    (
        set(("StridedSlice",)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (set(("Fill", "Pack", "Range")), all_parameter),
    (
        set(("Requantize",)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set(("Placeholder", "SubgraphInput", "Const", "VariableV2")), purpose_from_list([])),
    (set(("FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars")), input0_from_output_rest_parameter),
    (
        set(("Square", "Sqrt", "Log", "Less", "Enter", "Exit", "Identity", "StopGradient", "Merge", "Switch")),
        inputs_from_output,
    ),
    (None, all_fm),
]
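
# The table above is scanned in order and the first matching entry wins; the
# final (None, all_fm) row is the catch-all used for op types not listed.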


for ops, input_purpose in tensor_purposes:
    if ops is None:
        continue
    for op in ops:
        assert len(op) > 1, "string literal has been decomposed"
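
# The assert above guards against a missing trailing comma in a one-element
# tuple, which would silently decompose the string into characters:
#
#     set(("Abs"))   # -> {"A", "b", "s"}  (bug: a set of characters)
#     set(("Abs",))  # -> {"Abs"}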


def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
    def mark_tensor_helper(tens, purpose):
        if tens.purpose == TensorPurpose.Unknown or tens.purpose == purpose:
            tens.purpose = purpose
        else:
            assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
        tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]

        if len(tens.ops) == 1 and tens.ops[0].type == "Const":
            # special case constants, as they must be in permanent storage
            tens.mem_area = arch.permanent_storage_mem_area

    def rewrite_mark_tensor_purpose(op, arch):
        # find disconnected outputs and mark them as feature maps
        for tens in op.outputs:
            if not tens.consumers():
                mark_tensor_helper(tens, TensorPurpose.FeatureMap)

        for ops, input_purpose in tensor_purposes:
            if ops is None or op.type in ops:
                if ops is None:
                    print(
                        "Warning: Don't know how to mark up purpose for",
                        op.type,
                        op.inputs,
                        "- falling back to all feature maps",
                    )

                for idx, tens in enumerate(op.inputs):
                    purpose = input_purpose(op, idx)
                    mark_tensor_helper(tens, purpose)

                if op.type == "Reshape":
                    # Reshape's input and output point to the same data
                    op.outputs[0].mem_area = op.inputs[0].mem_area

                if op.type.startswith(custom_prefix) and op.attrs.get("custom_type", "") == "ExistingNpuOp":
                    scratch_tensor = None

                    if len(op.inputs) >= 3:
                        scratch_tensor = op.inputs[2]  # should be the existing scratch tensor
                        if scratch_tensor.name.endswith("_scratch"):
                            scratch_tensor.purpose = TensorPurpose.Scratch

                    if scratch_tensor is None:
                        raise OperatorError(op, "Scratch tensor not found.")

                break

        return op

    for sg in nng.subgraphs:
        sg = rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [rewrite_mark_tensor_purpose])
        for tens in sg.output_tensors:
            mark_tensor_helper(tens, TensorPurpose.FeatureMap)

    if verbose_tensor_purpose:
        nng.print_graph_with_tensors()

    return nng


reshape_operations = set(
    (
        "Reshape",
        "QuantizedReshape",
        "ExpandDims",
        "Squeeze",
        "BatchToSpaceND",
        "SpaceToBatchND",
        "DepthToSpaceND",
        "SpaceToDepthND",
        "Placeholder",
    )
)


def mark_tensor_format(nng, arch, verbose_tensor_format=False):
    formats_for_tensor = {}

    def init_tens(tens):
        if tens.purpose == TensorPurpose.FeatureMap:
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Weights:
            fmt = arch.default_weight_format
        elif tens.purpose == TensorPurpose.Scratch:
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Unknown:
            fmt = TensorFormat.Unknown
        else:
            assert 0, "unknown tensor purpose %s" % (tens.purpose,)
        return fmt

    def visit_tens(tens, ps):
        if tens not in formats_for_tensor:
            fmt = init_tens(tens)
        else:
            fmt = formats_for_tensor[tens]

        formats_for_tensor[tens] = fmt

    for sg in nng.subgraphs:
        for ps in sg.passes:
            for tens in ps.outputs:
                visit_tens(tens, ps)
            for tens in ps.intermediates:
                visit_tens(tens, ps)
            for tens in ps.inputs:
                visit_tens(tens, ps)

    for tens, fmt in formats_for_tensor.items():
        tens.set_format(fmt, arch)
        if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
            src_tens = tens.get_dma_src_tensor()
            if src_tens is not None:
                op = tens.find_npu_op()
                npu_block_type = op.attrs["npu_block_type"]
                weight_compressor.compress_weights(arch, nng, tens, npu_block_type, 32, 32, op.get_dilation_h_w())
                # Alias the compressed weights back into the source tensor
                src_tens.copy_compressed_weight_info(tens)

    if verbose_tensor_format:
        nng.print_passes_with_tensors()
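

# A minimal usage sketch (illustration only; the actual call sites live
# elsewhere in the compiler): both passes take the internal graph (nng) and
# the architecture configuration (arch). mark_tensor_purpose() is expected to
# run first, since mark_tensor_format() picks formats based on each tensor's
# purpose.
#
#     nng = mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False)
#     mark_tensor_format(nng, arch, verbose_tensor_format=False)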