# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Mark purpose and select formats for Tensors. Also compresses the weights.
from . import rewrite_graph
from . import weight_compressor
from .errors import OperatorError
from .operation import CustomType
from .operation import Op
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose


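# Helper callbacks with signature (op, idx) -> TensorPurpose; the
# tensor_purposes table below pairs sets of ops with these callbacks to
# assign a purpose to each op input. For example,
# purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights])
# returns a callback that marks input 0 as a feature map and input 1 as weights.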
def purpose_from_list(lst):
    def purpose(op, idx):
        return lst[idx]

    return purpose


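# Mark every input as a feature map.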
def all_fm(op, idx):
    return TensorPurpose.FeatureMap


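# Note: currently identical to all_fm; "parameter" inputs are treated as
# feature maps.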
def all_parameter(op, idx):
    return TensorPurpose.FeatureMap


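# Input 0 inherits the purpose of the op's first output; all other inputs are
# feature maps.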
def input0_from_output_rest_parameter(op, idx):
    if idx == 0:
        res = op.outputs[0].purpose
        if res == TensorPurpose.Unknown:
            print("Warning: Propagating unknown tensor purpose", op)
        return res
    return TensorPurpose.FeatureMap


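# Every input inherits the purpose of the op's first output.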
def inputs_from_output(op, idx):
    res = op.outputs[0].purpose
    if res == TensorPurpose.Unknown:
        print("Warning: Propagating unknown tensor purpose", op)
    return res


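# Rules mapping op types to an input_purpose callback. The rules are scanned
# in order and the first match wins; the final (None, all_fm) entry is the
# catch-all fallback.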
tensor_purposes = [  # ops, input_purpose
    (
        set(
            (
                Op.Relu,
                Op.Relu6,
                Op.Rsqrt,
                Op.Abs,
                Op.Cast,
                Op.Exp,
                Op.Floor,
                Op.FloorDiv,
                Op.FloorMod,
                Op.SquaredDifference,
                Op.AddN,
                Op.Maximum,
                Op.Minimum,
                Op.Sigmoid,
                Op.Tanh,
                Op.AvgPool,
                Op.MaxPool,
                Op.Squeeze,
                Op.Softmax,
                Op.LRN,
                Op.BatchMatMul,
                Op.ZerosLike,
                Op.Mul,
                Op.Add,
                Op.Sub,
                Op.Div,
                Op.LeakyRelu,
                Op.CLZ,
                Op.SHL,
                Op.SHR,
                Op.ReduceSum,
            )
        ),
        all_fm,
    ),
    (
        set((Op.Conv2D, Op.MatMul, Op.Conv2DBias, Op.DepthwiseConv2DBias, Op.FullyConnected,)),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap]),
    ),
    (
        set((Op.Conv2DBackpropInputSwitchedBias,)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.Weights, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set((Op.QuantizedConv2D, Op.QuantizedMatMul)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (
        set(
            (
                Op.Reshape,
                Op.Min,
                Op.Max,
                Op.Mean,
                Op.Pad,
                Op.MirrorPad,
                Op.ArgMax,
                Op.ArgMin,
                Op.ExpandDims,
                Op.ResizeNearestNeighbor,
                Op.ResizeBilinear,
                Op.Tile,
                Op.Transpose,
            )
        ),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set((Op.QuantizedReshape,)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (
        set((Op.Dequantize, Op.Quantize, Op.QuantizedAvgPool, Op.QuantizedMaxPool, Op.Slice, Op.SplitV,)),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set((Op.BatchToSpaceND, Op.SpaceToBatchND, Op.DepthToSpace, Op.SpaceToDepth)),
        purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]),
    ),
    (
        set((Op.BlockLSTM,)),
        purpose_from_list(
            [
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.Weights,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
                TensorPurpose.FeatureMap,
            ]
        ),
    ),
    (set((Op.SplitSliceRead,)), purpose_from_list([TensorPurpose.FeatureMap, TensorPurpose.FeatureMap])),
    (set((Op.Shape, Op.ConcatSliceWrite)), purpose_from_list([TensorPurpose.FeatureMap])),
    (
        set((Op.StridedSlice,)),
        purpose_from_list(
            [TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap, TensorPurpose.FeatureMap]
        ),
    ),
    (set((Op.Fill, Op.Pack, Op.Range)), all_parameter),
    (set((Op.Placeholder, Op.SubgraphInput, Op.Const,)), purpose_from_list([])),
    (set((Op.FakeQuantWithMinMaxArgs,)), input0_from_output_rest_parameter),
    (set((Op.Square, Op.Sqrt, Op.Log, Op.Less, Op.Identity,)), inputs_from_output,),
    (None, all_fm),
]


for ops, input_purpose in tensor_purposes:
    if ops is None:
        continue


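# Pass 1: assign a TensorPurpose, memory area and memory type to every tensor
# in the graph, using the tensor_purposes rules above.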
def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
    def mark_tensor_helper(tens, purpose):
        if tens.purpose == TensorPurpose.Unknown or tens.purpose == purpose:
            tens.purpose = purpose
        elif tens.purpose != TensorPurpose.LUT:
            assert 0, "Cannot resolve tensor purpose %s and %s for tensor %s" % (tens.purpose, purpose, tens)
        tens.mem_area = arch.tensor_storage_mem_area[tens.purpose]
        tens.mem_type = arch.tensor_storage_mem_type[tens.purpose]

        if len(tens.ops) == 1 and tens.ops[0].type == Op.Const:
            # special case constants, as they must be in permanent storage
            tens.mem_area = arch.permanent_storage_mem_area
            tens.mem_type = MemType.Permanent_NPU

    def rewrite_mark_tensor_purpose(op, arch, nng):
        # find disconnected outputs and mark as parameters
        for tens in op.outputs:
            if not tens.consumers():
                mark_tensor_helper(tens, TensorPurpose.FeatureMap)

        for ops, input_purpose in tensor_purposes:
            if ops is None or op.type in ops:
                if ops is None:
                    print(
                        "Warning: Don't know how to mark up purpose for",
                        op.type,
                        op.inputs,
                        "triggering all feature map fallback",
                    )

                for idx, tens in enumerate(op.inputs):
                    if tens is None:
                        continue
                    purpose = input_purpose(op, idx) if tens.purpose == TensorPurpose.Unknown else tens.purpose
                    mark_tensor_helper(tens, purpose)

                if op.type == Op.Reshape:
                    # Reshape's input and output point to same data
                    op.outputs[0].mem_area = op.inputs[0].mem_area

                if op.type == Op.Custom and op.attrs.get("custom_type") == CustomType.ExistingNpuOp:
                    scratch_tensor = None

                    if len(op.inputs) >= 3:
                        scratch_tensor = op.inputs[2]  # should be existing scratch tensor
                        if scratch_tensor.name.endswith("_scratch"):
                            scratch_tensor.purpose = TensorPurpose.Scratch

                    if scratch_tensor is None:
                        raise OperatorError(op, "Scratch tensor not found.")

                break

        return op

    for sg in nng.subgraphs:
        sg = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [rewrite_mark_tensor_purpose])
        for tens in sg.output_tensors:
            mark_tensor_helper(tens, TensorPurpose.FeatureMap)

    if verbose_tensor_purpose:
        nng.print_graph_with_tensors()

    return nng


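# Pass 2: select a TensorFormat for every tensor, and compress the weights of
# tensors that use TensorFormat.WeightsCompressed.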
def mark_tensor_format(nng, arch, verbose_tensor_format=False):
    formats_for_tensor = {}

    def init_tens(tens):
        if tens.purpose in (TensorPurpose.FeatureMap, TensorPurpose.LUT):
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Weights:
            fmt = arch.default_weight_format
        elif tens.purpose == TensorPurpose.Scratch:
            fmt = arch.default_feature_map_format
        elif tens.purpose == TensorPurpose.Unknown:
            fmt = TensorFormat.Unknown
        else:
            assert 0, "unknown tensor purpose %s" % (tens.purpose,)
        return fmt

    def visit_tens(tens, ps):
        if tens not in formats_for_tensor:
            fmt = init_tens(tens)
        else:
            fmt = formats_for_tensor[tens]

        formats_for_tensor[tens] = fmt

    for sg in nng.subgraphs:
        for ps in sg.passes:
            for tens in ps.outputs:
                visit_tens(tens, ps)
            for tens in ps.intermediates:
                visit_tens(tens, ps)
            for tens in ps.inputs:
                visit_tens(tens, ps)

    for tens, fmt in formats_for_tensor.items():
        if len(tens.shape) > 4:
            continue
        tens.set_format(fmt, arch)
        if fmt == TensorFormat.WeightsCompressed and tens.values is not None:
            src_tens = tens.get_dma_src_tensor()
            if src_tens is not None:
                op = tens.find_npu_op()
                if op is not None:
                    weight_compressor.compress_weights(
                        arch, nng, tens, op.type.npu_block_type, 16, 16, op.get_dilation_h_w()
                    )
                    # Alias compressed weights back into source tensor
                    src_tens.copy_compressed_weight_info(tens)

    if verbose_tensor_format:
        nng.print_passes_with_tensors()