# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph.
from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor


def replace_rescale_with_avg_pool(rescale_op):
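    """Convert a Rescale op in place into an AvgPool NOP and return the converted op."""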
    assert rescale_op.type == Op.Rescale

    avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")
    rescale_op_clone = rescale_op.clone()
    op = rescale_op
    op.attrs = avgpool_op.attrs.copy()
    op.type = Op.AvgPool
    DebugDatabase.add_optimised(rescale_op_clone, op)

    return op


def calc_skirt(kernel, input_shape, explicit_padding):
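    """Calculate the (top, left, bottom, right) padding and the skirt for the given kernel, input shape and explicit padding."""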
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))

    top, left, bottom, right = explicit_padding
    top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
    left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def add_padding_fields(op, arch, nng):
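    """Convert the explicit_padding attribute of NPU ops into padding and skirt fields."""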
    if op.run_on_npu:
        if "explicit_padding" in op.attrs:
            input_shape = op.ifm_shapes[0]

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                # TODO not yet supported, but there will be a need for separate handling
                assert False
            else:
                padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def remove_const_transpose(op, arch, nng):
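    """Fold a Transpose of a constant tensor into the constant itself and bypass the Transpose op."""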
    if op.type == Op.Transpose:
        removed = False
        if len(op.ifm.ops) == 1:
            prev_op = op.ifm.ops[0]
            if prev_op.type == Op.Const:
                # Transpose the Tensor and data and remove Transpose
                # TODO move to Tensor?
                reorder = op.attrs["perms"]
                shape = op.ifm.shape.copy()
                tens = op.ifm

                tens.shape = [shape[idx] for idx in reorder]
                tens.bandwidth_shape = tens.shape
                tens.storage_shape = tens.shape

                if tens.values is not None:
                    tens.values = tens.values.transpose(reorder)

                op.ofm.values = tens.values
                # Bypass the Transpose op
                prev_op.set_output_tensor(op.ofm)
                DebugDatabase.add_optimised(op, prev_op)
                removed = True

        if not removed:
            print("Warning: Cannot remove Transpose, and handling of Transpose is not supported")
            assert False

    return op


# TODO can we change to add for both TFLite and TOSA?
def insert_add_copy_op_after_tens(tens):
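    """Insert an Add-with-zero copy op after tens and rewire the original consumers to read the copied tensor."""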
    tens_cons_list_copy = tens.consumer_list.copy()
    copy_tens = tens.clone()

    name = tens.name + "_add"
    ifm2 = create_const_tensor(
        name + "_zero_scalar",
        [1],
        copy_tens.dtype,
        [0],
        copy_tens.dtype.as_numpy_type(),
        quantization=copy_tens.quantization,
    )
    copy_op = create_add_nop(name)
    copy_op.add_input_tensor(tens)
    copy_op.add_input_tensor(ifm2)
    copy_op.set_output_tensor(copy_tens)
    copy_op.set_ifm_ofm_shapes()
    copy_op.run_on_npu = True

    # Set copy_ifm consumers
    for tens_cons in tens_cons_list_copy:
        if tens_cons is not None:
            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
                if cons_inp == tens:
                    tens_cons.set_input_tensor(copy_tens, ifm_idx)

    DebugDatabase.add_optimised(tens.ops[0], copy_op)


def fix_sg_input_output_tosa(op, arch, nng):
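    """Insert a copy op for Reshape ops whose ifm and ofm both need to persist, so that the Reshape can later be removed."""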
    if not op.run_on_npu or op.type != Op.Reshape:
        return op

    # For the Reshape operators we want to remove, the tensors are removed as well.
    # But in order to do this, they cannot be outputs of the sg;
    # this needs to be fixed prior to the removal.
    # The solution is to add a copy op, to maintain the original tensor.
    # This is also valid when the Reshape ifm/ofm is produced or consumed by the CPU.

    # Check if operator ifm/ofm are sg ifm/ofm
    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
    # Check if the ifm is produced or the ofm is consumed by the CPU
    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)

    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
        # Both ifm and ofm need to persist, but only the ifm needs a copy, in order to remove the Reshape
        insert_add_copy_op_after_tens(op.ifm)

    return op


def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an add op for the given concat op/input feature map"""
    ofm = concat_op.ofm
    ifm2 = create_const_tensor(
        name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
    )
    add_op = create_add_nop(name)

    add_op.inputs = [ifm, ifm2]
    add_op.outputs = [ofm]
    add_op.write_offset = write_offset
    add_op.write_shape = ifm_shape
    ofm.ops.append(add_op)
    DebugDatabase.add_optimised(concat_op, add_op)
    add_op.ifm_shapes.append(ifm_shape)
    add_op.ifm_shapes.append(Shape4D(ifm2.shape))
    add_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    add_op.memory_function = Op.ConcatSliceWrite
    return add_op


# TODO Could be further optimized by checking the type of the consumer,
# rather than just mimicking the TFLite behaviour depending on type.
# TOSA bool_t not considered yet
def remove_splitsliceread(op, arch):
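    """Remove a SplitSliceRead op, either by moving the read to its single consumer or by inserting an Add NOP that performs the read."""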

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an add op needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type != Op.Reshape
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
            and op.ofm.dtype in (DataType.uint8, DataType.int8, DataType.int16)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            name = op.name + "_add"
            ofm = op.ofm
            ifm2 = create_const_tensor(
                name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
            )
            add_op = create_add_nop(name)
            add_op.inputs = [op.ifm, ifm2]
            add_op.outputs = [ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(add_op)
            add_op.ifm_shapes.append(op.ifm_shapes[0])
            add_op.ifm_shapes.append(Shape4D(ifm2.shape))
            add_op.ofm_shapes.append(op.ofm_shapes[0])
            add_op.read_offsets[0] = op.read_offsets[0]
            add_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, add_op)


def rewrite_concat_ops(op, arch):
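    """Rewrite a Concat op into one Add NOP per input, each writing its input to the correct offset of the output feature map."""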
    if not op.run_on_npu or not op.type == Op.Concat:
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    inputs = op.inputs
    axis = op.attrs["axis"]

    for idx, inp in enumerate(inputs):
        op.ifm_shapes[idx] = Shape4D(inp.shape)
        if axis >= 0:
            axis_4D = axis + (4 - len(inp.shape))
        else:
            axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def remove_reshapes(op, arch):
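    """Bypass Reshape ops that are mapped to the NPU."""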
    if op.run_on_npu and op.type == Op.Reshape:
        bypass_reshape_and_squeeze_ops(op)


def rewrite_activation(op, arch, nng):
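    """Rewrite ReluN/Clamp integer limits into zero-point adjusted min/max attributes, asserting that the activation can be fused with its producer."""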
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    prev_op = ifm.ops[0]

    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuseable = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if not fuseable:
        print("Warning: relu-like op cannot be fused, currently not supported")
        assert False

    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp
    else:
        print("Warning: Unknown TOSA activation Op")
        assert False

    return op


def rewrite_rescale(op, arch, nng):
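    """Fuse a Rescale op into the preceding convolution-type op via explicit scaling, or replace it with an AvgPool NOP for the supported standalone cases."""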
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedly")
            assert False
        ifm.quantization.zero_point = input_zp
        ofm.quantization.zero_point = output_zp
        for s, m in zip(shift, multiplier):
            # TODO these are the TOSA limitations
            assert m >= 0
            assert 2 <= s <= 62
            # TODO these are the HW limitations
            assert 0 <= s < (1 << 6)
        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)

        if double_round and scale32:
            rounding_mode = NpuRoundingMode.TFL
        else:
            rounding_mode = NpuRoundingMode.NATURAL

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            assert len(multiplier) == len(shift) == len(prev_op.bias.values)

            if ifm.dtype == DataType.int32 and per_channel:
                prev_op.explicit_scaling = explicit_scaling
                prev_op.rounding_mode = rounding_mode

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print("Warning, unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
                assert False
        # TODO which are the cases we need to and can do standalone Rescale?
        # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
        # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
        # limited to these at the moment:
        elif (
            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
        ):
            # Create NOP performing the RESCALE
            avgpool_op = replace_rescale_with_avg_pool(op)
            avgpool_op.rounding_mode = rounding_mode

            if per_channel:
                # TODO
                avgpool_op.explicit_scaling = explicit_scaling
                print("Warning, unsupported TOSA Rescale")
                assert False
            else:
                avgpool_op.explicit_scaling = explicit_scaling
        else:
            print("Warning, unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
            assert False
    return op


def fixup_quantization(op, arch, nng):
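    """Default any missing ifm/ifm2/ofm quantization zero points to 0."""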
    if op.ifm and op.ifm.quantization.zero_point is None:
        op.ifm.quantization.zero_point = 0
    if op.ifm2 and op.ifm2.quantization.zero_point is None:
        op.ifm2.quantization.zero_point = 0
    if op.ofm and op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = 0
    return op


def supported_operator_check(op, arch, nng):
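    """Mark whether the op can run on the NPU according to the TOSA supported operators check."""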
    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
    assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    return op


def tosa_optimise_graph(nng, arch):
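    """Run the TOSA graph optimisation passes over all subgraphs and return the modified network."""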
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Removal of Transpose
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
        )

    # Rewrite concat ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Removal of reshapes
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
        sg.refresh_after_modification()

    # Rewrite Operators step
    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    # Post-processing step 1
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
        )

    # Removal of Slice, needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_splitsliceread])
        sg.refresh_after_modification()

    # Post-processing step 2
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)

    return nng