# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of the TOSA-based network graph, using the rewrite_graph module to traverse the graph.
import numpy as np

from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor


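# Rewrites a TOSA Rescale op in place into an AvgPool NOP: the op keeps its name but takes the
# attrs of a newly created AvgPool; a clone of the original Rescale is kept only so that the
# debug database can map the optimised op back to it.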
def replace_rescale_with_avg_pool(rescale_op):
    assert rescale_op.type == Op.Rescale

    avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")
    rescale_op_clone = rescale_op.clone()
    op = rescale_op
    op.attrs = avgpool_op.attrs.copy()
    op.type = Op.AvgPool
    DebugDatabase.add_optimised(rescale_op_clone, op)

    return op


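# Calculates the explicit padding (top, left, bottom, right) and the "skirt" for a kernel.
# The skirt extends the top/left padding with the remaining rows/columns of the total padding
# needed by the dilated kernel, i.e. (top_pad, left_pad, ypad - top_pad, xpad - left_pad).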
def calc_skirt(kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))

    top, left, bottom, right = explicit_padding
    top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
    left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "explicit_padding" in op.attrs:
            input_shape = op.ifm_shapes[0]

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                # TODO not yet supported, but there will be a need for separate handling
                assert False
            else:
                padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


# Counts leading zeros of a (int32)
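# Illustrative values (not executed): count_leading_zeros(1) == 31, count_leading_zeros(0) == 32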
def count_leading_zeros(a):
    lz = int(32)
    if a != 0:
        mask = 1 << (32 - 1)
        lz = 0
        while (mask & a) == 0:
            mask = mask >> 1
            lz = lz + 1
    return lz


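# The scaling below folds the division by the kernel area into an explicit (multiplier, shift)
# pair, so that multiplier / 2**shift approximates 1 / kernel_wh.
# Illustrative worked example (not executed): a 3x3 kernel gives kernel_wh = 9,
# count_leading_zeros(8) = 28, so k = 4, shift = 34 and
# multiplier = (((1 << 30) + 1) << 4) // 9 = 1908874355, where 1908874355 / 2**34 ~ 1/9.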
def calc_scaling_avgpool(op, arch, nng):
    if op.type == Op.AvgPool:
        top, left, _, _ = op.attrs["explicit_padding"]
        # TODO Only supported when global scaling can be used,
        # that is, when there is no padding
        assert top == 0 and left == 0
        assert op.explicit_scaling is None
        multiplier = []
        shift = []

        kernel_wh = op.kernel.elements_wh()
        k = 32 - count_leading_zeros(kernel_wh - 1)
        numerator = np.int64(((1 << 30) + 1) << k)
        multiplier.append(numerator // kernel_wh)
        shift.append(30 + k)

        op.rounding_mode = NpuRoundingMode.NATURAL
        op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
    return op


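# Folds a Transpose of a constant tensor into the constant itself: the tensor's shape and values
# are permuted according to the "perms" attribute and the Transpose op is bypassed.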
def remove_const_transpose(op, arch, nng):
    if op.type == Op.Transpose:
        removed = False
        if len(op.ifm.ops) == 1:
            prev_op = op.ifm.ops[0]
            if prev_op.type == Op.Const:
                # Transpose the Tensor and data and remove Transpose
                # TODO move to Tensor?
                reorder = op.attrs["perms"]
                shape = op.ifm.shape.copy()
                tens = op.ifm

                tens.shape = [shape[idx] for idx in reorder]
                tens.bandwidth_shape = tens.shape
                tens.storage_shape = tens.shape

                if tens.values is not None:
                    tens.values = tens.values.transpose(reorder)

                op.ofm.values = tens.values
                # Bypass the Transpose op
                prev_op.set_output_tensor(op.ofm)
                DebugDatabase.add_optimised(op, prev_op)
                removed = True

        if not removed:
            print("Warning: Cannot remove Transpose, and handling of Transpose is not supported")
            assert False

    return op


# TODO can we change to add for both TFLite and TOSA?
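# Inserts an elementwise Add with a zero scalar as IFM2 after the given tensor, acting as a copy
# op: the original tensor keeps its producer and can remain a subgraph/CPU boundary, while all
# previous consumers are redirected to the copy.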
def insert_add_copy_op_after_tens(tens):
    tens_cons_list_copy = tens.consumer_list.copy()
    copy_tens = tens.clone()

    name = tens.name + "_add"
    ifm2 = create_const_tensor(
        name + "_zero_scalar",
        [1],
        copy_tens.dtype,
        [0],
        copy_tens.dtype.as_numpy_type(),
        quantization=copy_tens.quantization,
    )
    copy_op = create_add_nop(name)
    copy_op.add_input_tensor(tens)
    copy_op.add_input_tensor(ifm2)
    copy_op.set_output_tensor(copy_tens)
    copy_op.set_ifm_ofm_shapes()
    copy_op.run_on_npu = True

    # Set copy_ifm consumers
    for tens_cons in tens_cons_list_copy:
        if tens_cons is not None:
            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
                if cons_inp == tens:
                    tens_cons.set_input_tensor(copy_tens, ifm_idx)

    DebugDatabase.add_optimised(tens.ops[0], copy_op)


def fix_sg_input_output_tosa(op, arch, nng):
    if not op.run_on_npu or op.type != Op.Reshape:
        return op

    # For the Reshape operators we want to remove, tensors are also removed.
    # But in order to do this, they cannot be outputs of the sg;
    # this needs to be fixed prior to the removal.
    # The solution is to add a copy op, to maintain the original tensor.
    # This is also valid when the Reshape ifm is produced by,
    # respectively the ofm is consumed by, the CPU.

    # Check if operator ifm/ofm are sg ifm/ofm
    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
    # Check if the ifm is produced by, respectively the ofm is consumed by, the CPU
    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)

    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
        # Both ifm and ofm need to persist, but only the ifm needs a copy, in order to remove the Reshape
        insert_add_copy_op_after_tens(op.ifm)

    return op


def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an add op for the given concat op/input feature map"""
    ofm = concat_op.ofm
    ifm2 = create_const_tensor(
        name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
    )
    add_op = create_add_nop(name)

    add_op.inputs = [ifm, ifm2]
    add_op.outputs = [ofm]
    add_op.write_offset = write_offset
    add_op.write_shape = ifm_shape
    ofm.ops.append(add_op)
    DebugDatabase.add_optimised(concat_op, add_op)
    add_op.ifm_shapes.append(ifm_shape)
    add_op.ifm_shapes.append(Shape4D(ifm2.shape))
    add_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    add_op.memory_function = Op.ConcatSliceWrite
    return add_op


# TODO Could be further optimised by checking the type of the consumer,
# rather than just mimicking the TFLite behaviour depending on type.
# TOSA bool_t not considered yet
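# Moves the SplitSliceRead to its single NPU consumer when the consumer can perform the read
# itself; otherwise the read is materialised as an elementwise Add with a zero scalar that
# carries the read offset/shape.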
def remove_splitsliceread(op, arch):
    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an add op needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type != Op.Reshape
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
            and op.ofm.dtype in (DataType.uint8, DataType.int8, DataType.int16)
        ):
            # SplitSliceRead can be performed by the tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            name = op.name + "_add"
            ofm = op.ofm
            ifm2 = create_const_tensor(
                name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
            )
            add_op = create_add_nop(name)
            add_op.inputs = [op.ifm, ifm2]
            add_op.outputs = [ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(add_op)
            add_op.ifm_shapes.append(op.ifm_shapes[0])
            add_op.ifm_shapes.append(Shape4D(ifm2.shape))
            add_op.ofm_shapes.append(op.ofm_shapes[0])
            add_op.read_offsets[0] = op.read_offsets[0]
            add_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, add_op)


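# Rewrites a TOSA Concat into one Add per input, each writing its IFM into the shared OFM at the
# correct offset along the concat axis (using the ConcatSliceWrite memory function).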
def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type == Op.Concat:
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    inputs = op.inputs
    axis = op.attrs["axis"]

    for idx, inp in enumerate(inputs):
        op.ifm_shapes[idx] = Shape4D(inp.shape)
        if axis >= 0:
            axis_4D = axis + (4 - len(inp.shape))
        else:
            axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def remove_reshapes(op, arch):
    if op.run_on_npu and op.type == Op.Reshape:
        bypass_reshape_and_squeeze_ops(op)


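# Prepares a TOSA activation (Clamp/ReluN) for fusing with the preceding NPU op: checks that the
# producer can take a fused activation and converts the integer clamp limits into zero-point
# adjusted "min"/"max" attributes.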
def rewrite_activation(op, arch, nng):
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    prev_op = ifm.ops[0]

    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuseable = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if not fuseable:
        print("Warning: relu-like op cannot be fused, currently not supported")
        assert False

    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp
    else:
        print("Warning: Unknown TOSA activation Op")
        assert False

    return op


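# Handles a TOSA Rescale op. A Rescale applies, per channel or globally, an integer scaling of
# roughly out = ((in - input_zp) * multiplier) >> shift, plus output_zp. Where possible the
# scaling is attached to the preceding op as ExplicitScaling and the Rescale is bypassed;
# otherwise the Rescale is turned into a scaling AvgPool NOP.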
def rewrite_rescale(op, arch, nng):
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that the input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedly")
            assert False
        ifm.quantization.zero_point = input_zp
        ofm.quantization.zero_point = output_zp
        for s, m in zip(shift, multiplier):
            # TODO these are the TOSA limitations
            assert m >= 0
            assert 2 <= s <= 62
            # TODO these are the HW limitations
            assert 0 <= s < (1 << 6)
        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)

        if double_round and scale32:
            rounding_mode = NpuRoundingMode.TFL
        else:
            rounding_mode = NpuRoundingMode.NATURAL

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            assert len(multiplier) == len(shift) == len(prev_op.bias.values)

            if ifm.dtype == DataType.int32 and per_channel:
                prev_op.explicit_scaling = explicit_scaling
                prev_op.rounding_mode = rounding_mode

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print("Warning, unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
                assert False
        # TODO which are the cases where we need to, and can, do a standalone Rescale?
        # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
        # The origin might be the TFLite op QUANTIZE; should we look to see if they can be translated to QUANTIZE?
        # Limited to these at the moment:
        elif (
            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
        ):
            # Create a NOP performing the RESCALE
            avgpool_op = replace_rescale_with_avg_pool(op)
            avgpool_op.rounding_mode = rounding_mode

            if per_channel:
                # TODO
                avgpool_op.explicit_scaling = explicit_scaling
                print("Warning, unsupported TOSA Rescale")
                assert False
            else:
                avgpool_op.explicit_scaling = explicit_scaling
        else:
            print("Warning, unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
            assert False
    return op


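# Ensures that the quantization of every IFM/IFM2/OFM has a zero point, defaulting missing ones to 0.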
def fixup_quantization(op, arch, nng):
    if op.ifm and op.ifm.quantization.zero_point is None:
        op.ifm.quantization.zero_point = 0
    if op.ifm2 and op.ifm2.quantization.zero_point is None:
        op.ifm2.quantization.zero_point = 0
    if op.ofm and op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = 0
    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
    assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    return op


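# Top-level driver for the TOSA graph optimisation: runs the passes below, in order, over all
# subgraphs of the network.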
def tosa_optimise_graph(nng, arch):
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Removal of Transpose
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
        )

    # Rewrite concat ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Removal of reshapes
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
        sg.refresh_after_modification()

    # TODO, when and where to best handle calc_scaling_avgpool
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False,
        )

    # Rewrite Operators step
    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    # Post-processing step 1
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
        )

    # Removal of Slice, needs to be done after the optimisation passes above have been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_splitsliceread])
        sg.refresh_after_modification()

    # Post-processing step 2
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)

    return nng