# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of the TOSA-based network graph, using the rewrite_graph module to do the traversal of the graph.
from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation_util import create_avgpool_nop


def replace_rescale_with_avg_pool(rescale_op):
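    """Turn a Rescale op into an AvgPool NOP in place; the caller is responsible for setting the scaling on it."""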
    assert rescale_op.type == Op.Rescale

    avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")
    rescale_op_clone = rescale_op.clone()
    op = rescale_op
    op.attrs = avgpool_op.attrs.copy()
    op.type = Op.AvgPool
    DebugDatabase.add_optimised(rescale_op_clone, op)

    return op


def calc_skirt(kernel, input_shape, explicit_padding):
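    """Calculate the resulting padding and skirt from the dilated kernel, the input shape and the explicit TOSA padding."""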
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))

    top, left, bottom, right = explicit_padding
    top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
    left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def add_padding_fields(op, arch, nng):
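    """Replace an NPU op's "explicit_padding" attribute with the calculated padding and add the matching "skirt"."""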
    if op.run_on_npu:
        if "explicit_padding" in op.attrs:
            input_shape = op.ifm_shapes[0]

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                # TODO not yet supported; separate handling will be needed
                assert False
            else:
                padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def remove_const_transpose(op, arch, nng):
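    """Fold a Transpose of a constant tensor into the constant itself, then bypass the Transpose op."""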
    if op.type == Op.Transpose:
        removed = False
        if len(op.ifm.ops) == 1:
            prev_op = op.ifm.ops[0]
            if prev_op.type == Op.Const:
                # Transpose the tensor shape and data, then remove the Transpose
                # TODO move to Tensor?
                reorder = op.attrs["perms"]
                shape = op.ifm.shape.copy()
                tens = op.ifm

                tens.shape = [shape[idx] for idx in reorder]
                tens.bandwidth_shape = tens.shape
                tens.storage_shape = tens.shape

                if tens.values is not None:
                    tens.values = tens.values.transpose(reorder)

                op.ofm.values = tens.values
                # Bypass the Transpose op
                prev_op.set_output_tensor(op.ofm)
                DebugDatabase.add_optimised(op, prev_op)
                removed = True

        if not removed:
            print("Warning: cannot remove Transpose; standalone Transpose is not supported")
            assert False

    return op


def remove_reshapes(op, arch):
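    """Bypass Reshape ops that run on the NPU."""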
    if op.run_on_npu and op.type == Op.Reshape:
        bypass_reshape_and_squeeze_ops(op)


def rewrite_activation(op, arch, nng):
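    """Rewrite the min/max attributes of ReluN/Clamp relative to the zero point; the op must be fuseable with its producer."""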
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    prev_op = ifm.ops[0]

    # Note: the checks on prev_op below require that a first optimisation pass of the full graph has been performed
    fuseable = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if not fuseable:
        print("Warning: relu-like op cannot be fused, currently not supported")
        assert False

    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp
    else:
        print("Warning: unknown TOSA activation Op")
        assert False

    return op


def rewrite_rescale(op, arch, nng):
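    """Fuse a TOSA Rescale into the producing op where possible, otherwise replace it with a standalone AvgPool NOP."""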
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that the input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedly")
            assert False
        ifm.quantization.zero_point = input_zp
        ofm.quantization.zero_point = output_zp
        for s, m in zip(shift, multiplier):
            # TODO these are the TOSA limitations
            assert m >= 0
            assert 2 <= s <= 62
            # TODO these are the HW limitations
            assert 0 <= s < (1 << 6)
        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)

        if double_round and scale32:
            rounding_mode = NpuRoundingMode.TFL
        else:
            rounding_mode = NpuRoundingMode.NATURAL

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            assert len(multiplier) == len(shift) == len(prev_op.bias.values)

            if ifm.dtype == DataType.int32 and per_channel:
                prev_op.explicit_scaling = explicit_scaling
                prev_op.rounding_mode = rounding_mode

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print("Warning: unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
                assert False
        # TODO which are the cases we need to and can do standalone Rescale?
        # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
        # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
        # limited to these at the moment:
        elif (
            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
        ):
            # Create a NOP performing the RESCALE
            avgpool_op = replace_rescale_with_avg_pool(op)
            avgpool_op.rounding_mode = rounding_mode

            if per_channel:
                # TODO
                avgpool_op.explicit_scaling = explicit_scaling
                print("Warning: unsupported per-channel TOSA Rescale")
                assert False
            else:
                avgpool_op.explicit_scaling = explicit_scaling
        else:
            print("Warning: unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
            assert False
    return op


def fixup_quantization(op, arch, nng):
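    """Default any missing zero points on the op's IFM/IFM2/OFM quantization to 0."""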
    if op.ifm and op.ifm.quantization.zero_point is None:
        op.ifm.quantization.zero_point = 0
    if op.ifm2 and op.ifm2.quantization.zero_point is None:
        op.ifm2.quantization.zero_point = 0
    if op.ofm and op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = 0
    return op


def supported_operator_check(op, arch, nng):
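    """Mark whether the op can run on the NPU according to the TOSA supported-operators check."""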
    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
    assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    return op


def tosa_optimise_graph(nng, arch):
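    """Run the TOSA graph optimisation passes over all subgraphs of the network."""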
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Removal of Transpose
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
        )

    # Handle sg input/output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
        )

    # Removal of reshapes
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
        sg.refresh_after_modification()

    # Rewrite Operators step
    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    # Post-processing step 1
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
        )

    # Post-processing step 2
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)

    return nng