# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph.
import numpy as np

from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import Op
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor
from .tensor import create_equivalence_id


def replace_rescale_with_avg_pool(rescale_op):
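    # The Rescale op is rewritten in place into a no-op AvgPool; the pooling
    # itself changes no values, only the explicit scaling/rounding mode that is
    # attached to the AvgPool afterwards has an effect.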
    assert rescale_op.type == Op.Rescale

    avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")
    rescale_op_clone = rescale_op.clone()
    op = rescale_op
    op.attrs = avgpool_op.attrs.copy()
    op.type = Op.AvgPool
    DebugDatabase.add_optimised(rescale_op_clone, op)

    return op


def calc_skirt(kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))

    top, left, bottom, right = explicit_padding
    top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
    left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "explicit_padding" in op.attrs:
            input_shape = op.ifm_shapes[0]

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                # TODO not yet supported, but there will be a need for separate handling
                assert False
            else:
                padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


# Counts the leading zeroes of a, treated as an int32
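# e.g. count_leading_zeros(1) == 31 and count_leading_zeros(0) == 32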
def count_leading_zeros(a):
    lz = int(32)
    if a != 0:
        mask = 1 << (32 - 1)
        lz = 0
        while (mask & a) == 0:
            mask = mask >> 1
            lz = lz + 1
    return lz


def calc_scaling_avgpool(op, arch, nng):
    if op.type == Op.AvgPool:
        top, left, _, _ = op.attrs["explicit_padding"]
        # TODO Only supported when global scaling can be used,
        # that is when there is no padding
        assert top == 0 and left == 0
        assert op.explicit_scaling is None
        multiplier = []
        shift = []

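        # Approximate 1 / kernel_wh as multiplier * 2^-shift:
        # k = ceil(log2(kernel_wh)), multiplier ~= 2^(30 + k) / kernel_wh and
        # shift = 30 + k, e.g. kernel_wh = 9 gives k = 4 and shift = 34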
        kernel_wh = op.kernel.elements_wh()
        k = 32 - count_leading_zeros(kernel_wh - 1)
        numerator = np.int64(((1 << 30) + 1) << k)
        multiplier.append(numerator // kernel_wh)
        shift.append(30 + k)

        op.rounding_mode = NpuRoundingMode.NATURAL
        op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
    return op


def remove_const_transpose(op, arch, nng):
    if op.type == Op.Transpose:
        removed = False
        if len(op.ifm.ops) == 1:
            prev_op = op.ifm.ops[0]
            if prev_op.type == Op.Const:
                # Transpose the Tensor and data and remove Transpose
                # TODO move to Tensor?
                reorder = op.attrs["perms"]
                shape = op.ifm.shape.copy()
                tens = op.ifm

                tens.shape = [shape[idx] for idx in reorder]
                tens.bandwidth_shape = tens.shape
                tens.storage_shape = tens.shape

                if tens.values is not None:
                    tens.values = tens.values.transpose(reorder)

                op.ofm.values = tens.values
                # Bypass the Transpose op
                prev_op.set_output_tensor(op.ofm)
                DebugDatabase.add_optimised(op, prev_op)
                removed = True

        if not removed:
            print("Warning: Cannot remove Transpose, and handling of Transpose is not supported")
            assert False

    return op


# TODO can we change to add for both TFLite and TOSA?
def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
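    # The copy is implemented as an elementwise Add with a zero scalar constant
    # (an identity add) that writes the data of tens to a cloned output tensor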
    tens_cons_list_copy = tens.consumer_list.copy()
    copy_tens = tens.clone()

    name = tens.name + "_add"
    ifm2 = create_const_tensor(
        name + "_zero_scalar",
        [1],
        copy_tens.dtype,
        [0],
        copy_tens.dtype.as_numpy_type(),
        quantization=copy_tens.quantization,
    )
    copy_op = create_add_nop(name)
    copy_op.add_input_tensor(tens)
    copy_op.add_input_tensor(ifm2)
    copy_op.set_output_tensor(copy_tens)
    copy_op.ifm_shapes.append(ifm_ofm_shape)
    copy_op.ifm_shapes.append(Shape4D(ifm2.shape))
    copy_op.ofm_shapes.append(ifm_ofm_shape)
    copy_op.run_on_npu = True

    # Set copy_ifm consumers
    for tens_cons in tens_cons_list_copy:
        if tens_cons is not None:
            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
                if cons_inp == tens:
                    tens_cons.set_input_tensor(copy_tens, ifm_idx)

    DebugDatabase.add_optimised(tens.ops[0], copy_op)


def fix_sg_input_output_tosa(op, arch, nng):
    if not op.run_on_npu or op.type != Op.Reshape:
        return op

    # For the Reshape operators we want to remove, tensors are removed.
    # But in order to do this, they cannot be outputs of the sg;
    # this needs to be fixed prior to the removal.
    # The solution is to add a copy op, to maintain the original tensor.
    # This is also valid when the reshape ifm/ofm is produced respectively
    # consumed by the CPU

    # Check if operator ifm/ofm are sg ifm/ofm
    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
    # Check if ifm/ofm is produced respectively consumed by CPU
    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)

    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
        # Both ifm and ofm need to persist, but only the ifm needs a copy, in order to remove the Reshape

        # Decide on ifm/ofm shapes for the copy op based on ifm
        shape = op.ifm.shape.copy()
        # remove dimensions that are set to 1
        new_shape = []
        for dim in shape:
            if dim != 1:
                new_shape.append(dim)
        if not new_shape:
            new_shape = [1]

        rank = len(new_shape)
        if rank > 3:
            # Reshape so that batch becomes 1, by moving elements to H dimension
            n = rank - 2
            h = 1
            for i in range(n):
                h *= shape[i]
            new_shape = Shape4D(new_shape[n:]).with_height(h)
        else:
            new_shape = Shape4D(new_shape)

        insert_add_copy_op_after_tens(op.ifm, new_shape)

    return op


def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an add op for the given concat op/input feature map"""
    ofm = concat_op.ofm
    ifm2 = create_const_tensor(
        name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
    )
    add_op = create_add_nop(name)

    add_op.inputs = [ifm, ifm2]
    add_op.outputs = [ofm]
    add_op.write_offset = write_offset
    add_op.write_shape = ifm_shape
    ofm.ops.append(add_op)
    DebugDatabase.add_optimised(concat_op, add_op)
    add_op.ifm_shapes.append(ifm_shape)
    add_op.ifm_shapes.append(Shape4D(ifm2.shape))
    add_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    add_op.memory_function = Op.ConcatSliceWrite
    return add_op


# TODO Could be further optimized by checking the type of the consumer,
# rather than just mimicking the TFLite behaviour depending on type.
# TOSA bool_t not considered yet
def remove_splitsliceread(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an add op needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type != Op.Reshape
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
            and op.ofm.dtype in (DataType.uint8, DataType.int8, DataType.int16)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            name = op.name + "_add"
            ofm = op.ofm
            ifm2 = create_const_tensor(
                name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
            )
            add_op = create_add_nop(name)
            add_op.inputs = [op.ifm, ifm2]
            add_op.outputs = [ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(add_op)
            add_op.ifm_shapes.append(op.ifm_shapes[0])
            add_op.ifm_shapes.append(Shape4D(ifm2.shape))
            add_op.ofm_shapes.append(op.ofm_shapes[0])
            add_op.read_offsets[0] = op.read_offsets[0]
            add_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, add_op)


def rewrite_concat(op):
    if not op.run_on_npu or not op.type == Op.Concat:
        return

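    # Each input is rewritten to an elementwise Add (see create_add_for_concat)
    # that writes its data into the ofm at an increasing offset along the
    # concatenation axis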
    offset = 0
    inputs = op.inputs
    axis_4D = op.attrs["axis4D"]

    for idx, inp in enumerate(inputs):
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
        offset = concat_end
    assert op.ofm_shapes[0][axis_4D] == offset


def remove_reshapes(op, arch):
    if op.run_on_npu and op.type == Op.Reshape:
        bypass_memory_only_ops(op)


def rewrite_activation(op, arch, nng):
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp

    return op


def rewrite_rescale(op, arch, nng):
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedly")
            assert False
        ifm.quantization.zero_point = input_zp
        ofm.quantization.zero_point = output_zp
        for s, m in zip(shift, multiplier):
            # TODO these are the TOSA limitations
            assert m >= 0
            assert 2 <= s <= 62
            # TODO these are the HW limitations
            assert 0 <= s < (1 << 6)
        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)

        if double_round and scale32:
            rounding_mode = NpuRoundingMode.TFL
        else:
            rounding_mode = NpuRoundingMode.NATURAL

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            assert len(multiplier) == len(shift) == len(prev_op.bias.values)

            if ifm.dtype == DataType.int32 and per_channel:
                prev_op.explicit_scaling = explicit_scaling
                prev_op.rounding_mode = rounding_mode

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print("Warning, unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
                assert False
        # TODO which are the cases we need to and can do standalone Rescale?
        # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
        # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
        # limited to these at the moment:
        elif (
            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
        ):
            # Create NOP performing the RESCALE
            avgpool_op = replace_rescale_with_avg_pool(op)
            avgpool_op.rounding_mode = rounding_mode

            if per_channel:
                # TODO
                avgpool_op.explicit_scaling = explicit_scaling
                print("Warning, unsupported TOSA Rescale")
                assert False
            else:
                avgpool_op.explicit_scaling = explicit_scaling
        else:
            print("Warning, unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
            assert False
    return op


# TODO modified copy of TFLite, solution for TOSA PAD will change so reuse has not been considered
def convert_pad(op, arch, nng):
    """
    Rewrites PAD operator to an add that copies the IFM to the OFM
    + up to 4 add operators that fill the OFM with zeros at the borders.
    """

    if op.type != Op.Pad:
        return op

    # TODO assuming rank <= 4 and N = 1 for rank == 4
    # This is checked in tosa_supported_operators
    ifm = op.ifm
    assert ifm is not None
    ifm_shape = Shape4D(ifm.shape)
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    rank = len(ifm.shape)
    padding = op.inputs[1].values
    pad_depth = padding[-1]
    if not (pad_depth == 0).all():
        print("Warning: For PAD, padding in depth not supported yet")
        assert False

    top, bottom = 0, 0
    left, right = 0, 0
    if rank > 1:
        left, right = padding[-2][0], padding[-2][1]
    if rank > 2:
        top, bottom = padding[-3][0], padding[-3][1]
    if rank == 4 and not (padding[-4] == 0).all():
        print("Warning: For PAD, padding not supported in first dimension when rank == 4 yet")
        assert False

    # Add op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    add_op.activation = op.activation

    quant = ofm.quantization
    pad_value = ifm.quantization.zero_point
    ifm.quantization.zero_point = 0
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant,
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom))
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right))

    op.type = Op.ConcatTFLite
    return add_op


def convert_table_to_lut(op, arch, nng):
    # Converts table op to a no-op + LUT
    if op.type is not Op.Table:
        return op

    table = op.inputs[1]
    op.inputs.remove(table)
    op.set_ifm_ofm_shapes()

    return convert_to_lut(op, table.values, "table")


def decompose_elem_tensors_hwc(op):
    """
    Decomposes elementwise op if any of the ifm(s)/ofm are too large in any dimension to be handled by the NPU
    """
    max_t_size = 65535
    ofm_shape = op.write_shape if op.write_shape is not None else op.ofm_shapes[0]
    ifm_shape = op.read_shapes[0] if op.read_shapes[0] is not None else op.ifm_shapes[0]
    ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
    ifm2_shape = op.read_shapes[1] if op.read_shapes[1] is not None else ifm2_shape
    limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)

    if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
        ofm_split = ofm_shape.floordiv_const(max_t_size).add(1, 1, 1, 1)

        for height in range(ofm_split.height):
            for width in range(ofm_split.width):
                for depth in range(ofm_split.depth):
                    ofm_offset = Shape4D(0, height * max_t_size, width * max_t_size, depth * max_t_size)
                    ofm_part_shape = ofm_shape.clip(ofm_offset, limit_shape)
                    ofm_cut = (ofm_offset, ofm_part_shape)

                    ifm_d = depth * max_t_size if ifm_shape.depth == ofm_shape.depth else 0
                    ifm_w = width * max_t_size if ifm_shape.width == ofm_shape.width else 0
                    ifm_h = height * max_t_size if ifm_shape.height == ofm_shape.height else 0
                    ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
                    ifm_part_shape = ifm_shape.clip(ifm_offset, limit_shape)
                    ifm_cut = (ifm_offset, ifm_part_shape)

                    if ifm2_shape is not None:
                        ifm2_d = depth * max_t_size if ifm2_shape.depth == ofm_shape.depth else 0
                        ifm2_w = width * max_t_size if ifm2_shape.width == ofm_shape.width else 0
                        ifm2_h = height * max_t_size if ifm2_shape.height == ofm_shape.height else 0
                        ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
                        ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
                        ifm2_cut = (ifm2_offset, ifm2_part_shape)
                    else:
                        ifm2_cut = (None, None)

                    create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
        op.ofm.ops.remove(op)
        op.ifm.consumer_list.remove(op)
        if op.ifm2 is not None:
            op.ifm2.consumer_list.remove(op)
    return


def create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut):
    part_op = op.clone()
    ifm_read_offset = op.read_offsets[0] if op.read_offsets[0] is not None else Shape4D(0, 0, 0, 0)
    ofm_write_offset = op.write_offset if op.write_offset is not None else Shape4D(0, 0, 0, 0)
    ifm_offset, ifm_shape = ifm_cut
    ofm_offset, ofm_shape = ofm_cut

    part_op.read_offsets[0] = ifm_read_offset + ifm_offset
    part_op.read_shapes[0] = ifm_shape
    part_op.write_offset = ofm_write_offset + ofm_offset
    part_op.write_shape = ofm_shape
    part_op.ifm_shapes = op.ifm_shapes.copy()
    part_op.ofm_shapes = op.ofm_shapes.copy()
    part_op.ifm.consumer_list.append(part_op)
    op.ofm.ops.append(part_op)

    ifm2_offset, ifm2_shape = ifm2_cut
    if ifm2_offset:
        ifm2_read_offset = op.read_offsets[1] if op.read_offsets[1] is not None else Shape4D(0, 0, 0, 0)
        part_op.read_offsets[1] = ifm2_read_offset + ifm2_offset
        part_op.read_shapes[1] = ifm2_shape
        part_op.ifm2.consumer_list.append(part_op)

    return part_op


def get_nhwc_stride(shape):
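    # Element strides for an NHWC layout,
    # e.g. get_nhwc_stride(Shape4D(1, 2, 3, 4)) == Shape4D(24, 12, 4, 1)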
    stride_x = shape.depth
    stride_y = shape.width * stride_x
    stride_n = shape.height * stride_y
    return Shape4D(stride_n, stride_y, stride_x, 1)


def pad_to_rank(shape, rank):
    """
    Pads a shape to the given rank
    """
    while len(shape) < rank:
        shape = [1] + shape

    return shape


def get_elem_shapes_removed_singles(op):
    """
    Returns the shapes of ifm(s)/ofm after removing all the dimensions that are 1 for all ifm(s)/ofm
    """
    binary = op.ifm2 is not None
    ofm_shape = op.ofm_shapes[0].as_list() if len(op.ofm_shapes) > 0 else op.ofm.shape
    ifm_shape = op.ifm_shapes[0].as_list() if len(op.ifm_shapes) > 0 else op.ifm.shape
    if binary:
        ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ofm_shapes) else op.ifm2.shape

    rank = max(len(ofm_shape), len(ifm_shape), len(ifm2_shape) if binary else 0)
    ofm_shape = pad_to_rank(ofm_shape, rank)
    ifm_shape = pad_to_rank(ifm_shape, rank)
    if binary:
        ifm2_shape = pad_to_rank(ifm2_shape, rank)

    new_ofm_shape = []
    new_ifm_shape = []
    new_ifm2_shape = []
    for idx in range(rank):
        if ofm_shape[idx] != 1:
            new_ofm_shape.append(ofm_shape[idx])
            new_ifm_shape.append(ifm_shape[idx])
            if binary:
                new_ifm2_shape.append(ifm2_shape[idx])

    if new_ofm_shape == []:
        new_ofm_shape = [1]
        new_ifm_shape = [1]
        new_ifm2_shape = [1] if binary else None

    return new_ofm_shape, new_ifm_shape, new_ifm2_shape


def decomp_dims_elementwise(op):
    """
    Decompose elementwise ops with Rank > 3 (H,W,D).
    If Rank > 3, all the dimensions above H are viewed as the N dimension.
    The elementwise operation will be decomposed to N (of ofm) elementwise operations,
    reading and writing with offsets from/to the ifm(s)/ofm.
    Note: Broadcast needs to be handled for binary elementwise ops, and TOSA allows for broadcast by both ifm and ifm2
    """
644
645 ifm = op.ifm
646 ifm2 = op.ifm2
647 ofm = op.ofm
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200648 binary = op.ifm2 is not None
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200649
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200650 # Remove dimensions that are all 1
651 new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
652 rank = len(new_ofm_shape)
653
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200654 if rank > 3:
655 n = rank - 3
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200656 ofm_decomp_shape = Shape4D(new_ofm_shape[0:n])
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200657 ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200658 ofm_part_shape = Shape4D(new_ofm_shape[n:])
659 op.ofm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200660
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200661 if binary:
662 ifm_decomp_shape = Shape4D(new_ifm_shape[0:n])
663 ifm2_decomp_shape = Shape4D(new_ifm2_shape[0:n])
664 ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
665 ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
666 ifm_part_shape = Shape4D(new_ifm_shape[n:])
667 ifm2_part_shape = Shape4D(new_ifm2_shape[n:])
668 op.ifm_shapes.append(Shape4D([ifm_decomp_shape.elements()] + new_ifm_shape[n:]))
669 op.ifm_shapes.append(Shape4D([ifm2_decomp_shape.elements()] + new_ifm2_shape[n:]))
670 else:
671 op.ifm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200672
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200673 op_list = []
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200674 for height in range(ofm_decomp_shape.height):
675 for width in range(ofm_decomp_shape.width):
676 for depth in range(ofm_decomp_shape.depth):
677 ofm_offset = Shape4D(0, height, width, depth)
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200678 ofm_offset = Shape4D(ofm_offset.dot_prod(ofm_decomp_stride), 0, 0, 0)
679 ofm_cut = (ofm_offset, ofm_part_shape)
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200680
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200681 if binary:
682 ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
683 ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
684 ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
685 ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
686 ifm_offset = Shape4D(ifm_offset.dot_prod(ifm_decomp_stride), 0, 0, 0)
687 ifm_cut = (ifm_offset, ifm_part_shape)
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200688
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200689 ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
690 ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
691 ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
692 ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
693 ifm2_offset = Shape4D(ifm2_offset.dot_prod(ifm2_decomp_stride), 0, 0, 0)
694 ifm2_cut = (ifm2_offset, ifm2_part_shape)
695 op_list.append(create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut))
696 else:
697 op_list.append(create_elem_part_op(op, ofm_cut, None, ofm_cut))
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200698
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200699 ofm.ops.remove(op)
Patrik Gustavsson3f22ec22021-09-21 14:18:44 +0200700 ifm.consumer_list.remove(op)
701 if binary:
702 ifm2.consumer_list.remove(op)
703 else:
704 op.ofm_shapes.append(Shape4D(new_ofm_shape))
705 op.ifm_shapes.append(Shape4D(new_ifm_shape))
706 op.ifm_shapes.append(Shape4D(new_ifm2_shape))
707
708 return [op]
Patrik Gustavsson46408a82021-09-20 10:47:47 +0200709
710
def decomp_elementwise(tens, arch, nng):
    """
    Decompose elementwise ops with Rank > 3 (H,W,C).
    Decompose tensors exceeding the NPU max size
    """
    tens_ops = tens.ops.copy()
    for op in tens_ops:
        if op.type.is_elementwise_op():
            decomp_list = decomp_dims_elementwise(op)
            for part_op in decomp_list:
                decompose_elem_tensors_hwc(part_op)
    return tens


def reshape_concat_shape(shape, rank, axis):
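    # Collapses all dimensions before the concat axis into H and all dimensions
    # after it into C, e.g. shape [2, 3, 4, 5] with axis 2 becomes [6, 4, 5]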
    new_h = 1
    for i in range(axis):
        new_h *= shape[i]
    new_c = 1
    for i in range(axis + 1, rank):
        new_c *= shape[i]
    if axis == (rank - 1):
        new_shape = [new_h, shape[axis], 1]
    else:
        new_shape = [new_h, shape[axis], new_c]
    return new_shape


def reshape_concat(op):
    """
    Reshapes concat ops with Rank > 3 (H,W,C).
    """
    ofm = op.ofm
    rank = len(ofm.shape)
    axis = op.attrs["axis"]
    if axis < 0:
        axis += rank

    if rank > 3:
        # Reshape so that the axis to be concatenated is the W dimension
        # Reshape inputs
        for inp in op.inputs:
            new_shape = reshape_concat_shape(inp.shape, rank, axis)
            op.ifm_shapes.append(Shape4D(new_shape))
        # Reshape output
        new_shape = reshape_concat_shape(ofm.shape, rank, axis)
        op.ofm_shapes.append(Shape4D(new_shape))
        op.attrs["axis4D"] = 2
    else:
        for inp in op.inputs:
            op.ifm_shapes.append(Shape4D(inp.shape))
        op.ofm_shapes.append(Shape4D(ofm.shape))
        op.attrs["axis4D"] = axis + (4 - rank)


def decomp_rewrite_concat(tens, arch, nng):
    """
    Decompose concat ops with Rank > 3 (H,W,C).
    Rewrite of concat to elementwise operations
    """
    if len(tens.ops) == 1 and tens.ops[0].type == Op.Concat:
        op = tens.ops[0]

        reshape_concat(op)
        rewrite_concat(op)

        op.ofm.ops.remove(op)
        for inp in op.inputs:
            inp.consumer_list.remove(op)

    return tens


def fixup_quantization(op, arch, nng):
    if op.ifm and op.ifm.quantization.zero_point is None:
        op.ifm.quantization.zero_point = 0
    if op.ifm2 and op.ifm2.quantization.zero_point is None:
        op.ifm2.quantization.zero_point = 0
    if not op.forced_output_quantization:
        if op.ofm and op.ofm.quantization and op.ofm.quantization.zero_point is None:
            op.ofm.quantization.zero_point = 0
    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
    assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    return op


def tosa_optimise_graph(nng, arch):

    # TODO the supported operator checking needs to be split into semantic and HW checks
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [supported_operator_check], rewrite_unsupported=False,
        )

    # Decomposing and rewrite of concat
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
        )

    # Removal of reshapes
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
        sg.refresh_after_modification()

    # Decomposing of elementwise
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [set_ifm_ofm_op_shapes], rewrite_unsupported=False,
        )

    # Removal of Transpose
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
        )

    # TODO, when and where to best handle calc_scaling_avgpool
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False,
        )

    # Rewrite Operators step
    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv, convert_table_to_lut]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    # Post-processing step 1
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [rewrite_activation, convert_pad, add_padding_fields],
        )

    # Removal of Slice, needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_splitsliceread])
        sg.refresh_after_modification()

    # Post-processing step 2
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)

    return nng