# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of the TOSA-based network graph, using the rewrite_graph module to do the traversal of the graph.
import numpy as np

from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import Op
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_pad_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import shape_num_elements
from .tensor import Tensor


def replace_rescale_with_avg_pool(rescale_op):
    assert rescale_op.type == Op.Rescale

    avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")
    rescale_op_clone = rescale_op.clone()
    op = rescale_op
    op.attrs = avgpool_op.attrs.copy()
    op.type = Op.AvgPool
    DebugDatabase.add_optimised(rescale_op_clone, op)

    return op


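# Calculates the explicit padding (top, left, bottom, right) and the "skirt" for an op;
# the skirt here is (top, left, ypad - top, xpad - left), i.e. the part of the total needed
# padding that is not covered by the explicit top/left padding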
def calc_skirt(kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))

    top, left, bottom, right = explicit_padding
    top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
    left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "explicit_padding" in op.attrs:
            input_shape = op.ifm_shapes[0]

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                # TODO not yet supported, but it will need separate handling
                assert False
            else:
                padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


# Counts the leading zeros of a (treated as an int32)
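# e.g. count_leading_zeros(1) == 31 and count_leading_zeros(0) == 32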
def count_leading_zeros(a):
    lz = int(32)
    if a != 0:
        mask = 1 << (32 - 1)
        lz = 0
        while (mask & a) == 0:
            mask = mask >> 1
            lz = lz + 1
    return lz


def calc_scaling_avgpool(op, arch, nng):
    if op.type == Op.AvgPool:
        top, left, _, _ = op.attrs["explicit_padding"]
        # TODO Only supported when global scaling can be used,
        # that is, when there is no padding
        assert top == 0 and left == 0
        assert op.explicit_scaling is None
        multiplier = []
        shift = []

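        # Approximate 1 / (kernel width * kernel height) in fixed point:
        # with k = ceil(log2(kernel_wh)), multiplier ~= 2^(30 + k) / kernel_wh,
        # so that multiplying by multiplier and shifting right by (30 + k) divides by the kernel size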
        kernel_wh = op.kernel.elements_wh()
        k = 32 - count_leading_zeros(kernel_wh - 1)
        numerator = np.int64(((1 << 30) + 1) << k)
        multiplier.append(numerator // kernel_wh)
        shift.append(30 + k)

        op.rounding_mode = NpuRoundingMode.NATURAL
        op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
    return op


def remove_const_transpose(op, arch, nng):
    if op.type == Op.Transpose:
        removed = False
        if len(op.ifm.ops) == 1:
            prev_op = op.ifm.ops[0]
            if prev_op.type == Op.Const:
                # Transpose the tensor and its data, then remove the Transpose op
                # TODO move to Tensor?
                reorder = op.attrs["perms"]
                shape = op.ifm.shape.copy()
                tens = op.ifm

                tens.shape = [shape[idx] for idx in reorder]
                tens.bandwidth_shape = tens.shape
                tens.storage_shape = tens.shape

                if tens.values is not None:
                    tens.values = tens.values.transpose(reorder)

                op.ofm.values = tens.values
                # Bypass the Transpose op
                prev_op.set_output_tensor(op.ofm)
                DebugDatabase.add_optimised(op, prev_op)
                removed = True

        if not removed:
            print("Warning: Cannot remove Transpose, and handling of Transpose is not supported")
            assert False

    return op


# TODO can we change to use an add for both TFLite and TOSA?
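# Inserts an identity Add (tens + 0) after tens, so that the original tensor is kept;
# all previous consumers of tens are redirected to read the copy instead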
def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
    tens_cons_list_copy = tens.consumer_list.copy()
    copy_tens = tens.clone()

    name = tens.name + "_add"
    ifm2 = create_const_tensor(
        name + "_zero_scalar",
        [1],
        copy_tens.dtype,
        [0],
        copy_tens.dtype.as_numpy_type(),
        quantization=copy_tens.quantization,
    )
    copy_op = create_add_nop(name)
    copy_op.add_input_tensor(tens)
    copy_op.add_input_tensor(ifm2)
    copy_op.set_output_tensor(copy_tens)
    copy_op.ifm_shapes.append(ifm_ofm_shape)
    copy_op.ifm_shapes.append(Shape4D(ifm2.shape))
    copy_op.ofm_shapes.append(ifm_ofm_shape)
    copy_op.run_on_npu = True

    # Set copy_ifm consumers
    for tens_cons in tens_cons_list_copy:
        if tens_cons is not None:
            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
                if cons_inp == tens:
                    tens_cons.set_input_tensor(copy_tens, ifm_idx)

    DebugDatabase.add_optimised(tens.ops[0], copy_op)


def fix_sg_input_output_tosa(op, arch, nng):
    if not op.run_on_npu or op.type not in (Op.Reshape, Op.Identity):
        return op

    # For the Reshape operators we want to remove, tensors are removed.
    # But in order to do this, they cannot be outputs of the sg;
    # this needs to be fixed prior to the removal.
    # The solution is to add a copy op, to maintain the original tensor.
    # This is also valid when the Reshape ifm/ofm is produced respectively
    # consumed by the CPU.

    # Check if operator ifm/ofm are sg ifm/ofm
    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
    # Check if ifm/ofm is produced respectively consumed by CPU
    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)

    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
        # Both ifm and ofm need to persist, but only the ifm needs a copy, in order to remove the Reshape

        # Decide on ifm/ofm shapes for the copy op based on ifm
        shape = op.ifm.shape.copy()
        # remove dimensions that are set to 1
        new_shape = []
        for dim in shape:
            if dim != 1:
                new_shape.append(dim)
        if not new_shape:
            new_shape = [1]

        rank = len(new_shape)
        if rank > 3:
            # Reshape so that batch becomes 1, by moving elements to the H dimension
            n = rank - 2
            h = 1
            for i in range(n):
                h *= shape[i]
            new_shape = Shape4D(new_shape[n:]).with_height(h)
        else:
            new_shape = Shape4D(new_shape)

        insert_add_copy_op_after_tens(op.ifm, new_shape)

    return op


def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an add op for the given concat op/input feature map"""
    ofm = concat_op.ofm
    ifm2 = create_const_tensor(
        name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
    )
    add_op = create_add_nop(name)

    add_op.inputs = [ifm, ifm2]
    add_op.outputs = [ofm]
    add_op.write_offset = write_offset
    add_op.write_shape = ifm_shape
    ofm.ops.append(add_op)
    DebugDatabase.add_optimised(concat_op, add_op)
    add_op.ifm_shapes.append(ifm_shape)
    add_op.ifm_shapes.append(Shape4D(ifm2.shape))
    add_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    add_op.memory_function = Op.ConcatSliceWrite
    return add_op


# TODO Could be further optimized by checking the type of the consumer,
# rather than just mimicking the TFLite behaviour depending on type.
# TOSA bool_t not considered yet
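# Removes a SplitSliceRead either by moving the read offset/shape to its single NPU consumer,
# or, when that is not possible, by materialising the read with an identity Add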
def remove_splitsliceread(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type != Op.Reshape
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
            and op.ofm.dtype in (DataType.uint8, DataType.int8, DataType.int16)
        ):
            # SplitSliceRead can be performed by the tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            name = op.name + "_add"
            ofm = op.ofm
            ifm2 = create_const_tensor(
                name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
            )
            add_op = create_add_nop(name)
            add_op.inputs = [op.ifm, ifm2]
            add_op.outputs = [ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(add_op)
            add_op.ifm_shapes.append(op.ifm_shapes[0])
            add_op.ifm_shapes.append(Shape4D(ifm2.shape))
            add_op.ofm_shapes.append(op.ofm_shapes[0])
            add_op.read_offsets[0] = op.read_offsets[0]
            add_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, add_op)


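# Rewrites a Concat op into one ConcatSliceWrite add per input,
# each writing its ifm at the correct offset along the 4D concat axis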
def rewrite_concat(op):
    if not op.run_on_npu or not op.type == Op.Concat:
        return

    offset = 0
    inputs = op.inputs
    axis_4D = op.attrs["axis4D"]

    for idx, inp in enumerate(inputs):
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
        offset = concat_end
    assert op.ofm_shapes[0][axis_4D] == offset


def remove_memory_ops(op, arch):
    if op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
        bypass_memory_only_ops(op)


def rewrite_activation(op, arch, nng):
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp

    return op


def rewrite_rescale(op, arch, nng):
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of tensor's producer/consumer differs unexpectedly")
            assert False
        ifm.quantization.zero_point = input_zp
        ofm.quantization.zero_point = output_zp
        for s, m in zip(shift, multiplier):
            # TODO these are the TOSA limitations
            assert m >= 0
            assert 2 <= s <= 62
            # TODO these are the HW limitations
            assert 0 <= s < (1 << 6)
        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)

        if double_round and scale32:
            rounding_mode = NpuRoundingMode.TFL
        else:
            rounding_mode = NpuRoundingMode.NATURAL

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            assert len(multiplier) == len(shift) == len(prev_op.bias.values)

            if ifm.dtype == DataType.int32 and per_channel:
                prev_op.explicit_scaling = explicit_scaling
                prev_op.rounding_mode = rounding_mode

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print("Warning, unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
                assert False
        # TODO which are the cases we need to and can do standalone Rescale?
        # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
        # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
        # limited to these at the moment:
        elif (
            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
        ):
            # Create NOP performing the RESCALE
            avgpool_op = replace_rescale_with_avg_pool(op)
            avgpool_op.rounding_mode = rounding_mode

            if per_channel:
                # TODO
                avgpool_op.explicit_scaling = explicit_scaling
                print("Warning, unsupported TOSA Rescale")
                assert False
            else:
                avgpool_op.explicit_scaling = explicit_scaling
        else:
            print("Warning, unsupported fusing of TOSA Rescale, previous operator is of type:", prev_op.type)
            assert False
    return op


def convert_pad_in_width(op):
    """
    Rewrites a PAD operator (padding in the width dimension) to an add op that copies the IFM to
    the OFM, plus up to two add ops that fill the left/right borders of the OFM with the pad value.
    """
    assert op.type == Op.Pad
    assert op.ifm_shapes[0] is not None and op.ofm_shapes[0] is not None
    ifm = op.ifm
    ofm = op.ofm
    ifm_shape = op.ifm_shapes[0]
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    padding = op.inputs[1].values
    left, right = padding[-2]

    # Add op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp0.with_width(left))
    add_op.activation = op.activation

    quant = ofm.quantization
    pad_value = ifm.quantization.zero_point
    ifm.quantization.zero_point = 0
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp0)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp0.with_width(ofm_shape.width - right))

    op.type = Op.ConcatTFLite
    return add_op


def convert_table_to_lut(op, arch, nng):
    # Converts table op to a no-op + LUT
    if op.type is not Op.Table:
        return op

    table = op.inputs[1]
    op.inputs.remove(table)
    op.set_ifm_ofm_shapes()

    return convert_to_lut(op, table.values, "table")


def decompose_elem_tensors_hwc(op):
    """
    Decomposes an elementwise op if any of the ifm(s)/ofm are too large in any dimension to be handled by the NPU
    """
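    # Any H/W/C dimension larger than max_t_size is processed in blocks of at most max_t_size elements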
    max_t_size = 65535
    ofm_shape = op.write_shape if op.write_shape is not None else op.ofm_shapes[0]
    ifm_shape = op.read_shapes[0] if op.read_shapes[0] is not None else op.ifm_shapes[0]
    ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
    ifm2_shape = op.read_shapes[1] if op.read_shapes[1] is not None else ifm2_shape
    limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)

    if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
        ofm_split = ofm_shape.floordiv_const(max_t_size).add(1, 1, 1, 1)

        for height in range(ofm_split.height):
            for width in range(ofm_split.width):
                for depth in range(ofm_split.depth):
                    ofm_offset = Shape4D(0, height * max_t_size, width * max_t_size, depth * max_t_size)
                    ofm_part_shape = ofm_shape.clip(ofm_offset, limit_shape)
                    ofm_cut = (ofm_offset, ofm_part_shape)

                    ifm_d = depth * max_t_size if ifm_shape.depth == ofm_shape.depth else 0
                    ifm_w = width * max_t_size if ifm_shape.width == ofm_shape.width else 0
                    ifm_h = height * max_t_size if ifm_shape.height == ofm_shape.height else 0
                    ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
                    ifm_part_shape = ifm_shape.clip(ifm_offset, limit_shape)
                    ifm_cut = (ifm_offset, ifm_part_shape)

                    if ifm2_shape is not None:
                        ifm2_d = depth * max_t_size if ifm2_shape.depth == ofm_shape.depth else 0
                        ifm2_w = width * max_t_size if ifm2_shape.width == ofm_shape.width else 0
                        ifm2_h = height * max_t_size if ifm2_shape.height == ofm_shape.height else 0
                        ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
                        ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
                        ifm2_cut = (ifm2_offset, ifm2_part_shape)
                    else:
                        ifm2_cut = (None, None)

                    create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
        op.ofm.ops.remove(op)
        op.ifm.consumer_list.remove(op)
        if op.ifm2 is not None:
            op.ifm2.consumer_list.remove(op)
    return


def create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut):
    part_op = op.clone()
    ifm_read_offset = op.read_offsets[0] if op.read_offsets[0] is not None else Shape4D(0, 0, 0, 0)
    ofm_write_offset = op.write_offset if op.write_offset is not None else Shape4D(0, 0, 0, 0)
    ifm_offset, ifm_shape = ifm_cut
    ofm_offset, ofm_shape = ofm_cut

    part_op.read_offsets[0] = ifm_read_offset + ifm_offset
    part_op.read_shapes[0] = ifm_shape
    part_op.write_offset = ofm_write_offset + ofm_offset
    part_op.write_shape = ofm_shape
    part_op.ifm_shapes = op.ifm_shapes.copy()
    part_op.ofm_shapes = op.ofm_shapes.copy()
    part_op.ifm.consumer_list.append(part_op)
    op.ofm.ops.append(part_op)

    ifm2_offset, ifm2_shape = ifm2_cut
    if ifm2_offset:
        ifm2_read_offset = op.read_offsets[1] if op.read_offsets[1] is not None else Shape4D(0, 0, 0, 0)
        part_op.read_offsets[1] = ifm2_read_offset + ifm2_offset
        part_op.read_shapes[1] = ifm2_shape
        part_op.ifm2.consumer_list.append(part_op)

    return part_op


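# Returns the NHWC element strides for a shape,
# e.g. get_nhwc_stride(Shape4D(2, 3, 4, 5)) -> Shape4D(60, 20, 5, 1)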
def get_nhwc_stride(shape):
    stride_x = shape.depth
    stride_y = shape.width * stride_x
    stride_n = shape.height * stride_y
    return Shape4D(stride_n, stride_y, stride_x, 1)


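# e.g. pad_to_rank([20, 10], 4) -> [1, 1, 20, 10]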
def pad_to_rank(shape, rank):
    """
    Pads a shape to the given rank
    """
    while len(shape) < rank:
        shape = [1] + shape

    return shape


def get_elem_shapes_removed_singles(op):
    """
    Returns the shapes of the ifm(s)/ofm after removing all dimensions that are 1 for all ifm(s)/ofm
    """
    binary = op.ifm2 is not None
    ofm_shape = op.ofm_shapes[0].as_list() if len(op.ofm_shapes) > 0 else op.ofm.shape
    ifm_shape = op.ifm_shapes[0].as_list() if len(op.ifm_shapes) > 0 else op.ifm.shape
    if binary:
        ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ofm_shapes) else op.ifm2.shape

    rank = max(len(ofm_shape), len(ifm_shape), len(ifm2_shape) if binary else 0)
    ofm_shape = pad_to_rank(ofm_shape, rank)
    ifm_shape = pad_to_rank(ifm_shape, rank)
    if binary:
        ifm2_shape = pad_to_rank(ifm2_shape, rank)

    new_ofm_shape = []
    new_ifm_shape = []
    new_ifm2_shape = []
    for idx in range(rank):
        if ofm_shape[idx] != 1:
            new_ofm_shape.append(ofm_shape[idx])
            new_ifm_shape.append(ifm_shape[idx])
            if binary:
                new_ifm2_shape.append(ifm2_shape[idx])

    if new_ofm_shape == []:
        new_ofm_shape = [1]
        new_ifm_shape = [1]
        new_ifm2_shape = [1] if binary else None

    return new_ofm_shape, new_ifm_shape, new_ifm2_shape


def decomp_dims_elementwise(op):
    """
    Decompose elementwise ops with Rank > 3 (H,W,C).
    If Rank > 3, all the dimensions above H are viewed as the N dimension.
    The elementwise operation will be decomposed into N (of the ofm) elementwise operations,
    reading and writing with offsets from/to the ifm(s)/ofm.
    Note: Broadcast needs to be handled for binary elementwise ops, and TOSA allows for broadcast by both ifm and ifm2
    """
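    # For example, an ofm with (singles removed) shape [2, 3, 10, 10, 8] is decomposed into
    # 2 * 3 = 6 part ops, each reading/writing a [10, 10, 8] block at the corresponding offset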

    ifm = op.ifm
    ifm2 = op.ifm2
    ofm = op.ofm
    binary = op.ifm2 is not None

    # Remove dimensions that are all 1
    new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
    rank = len(new_ofm_shape)

    if rank > 3:
        n = rank - 3
        ofm_decomp_shape = Shape4D(new_ofm_shape[0:n])
        ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
        ofm_part_shape = Shape4D(new_ofm_shape[n:])
        op.ofm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))

        if binary:
            ifm_decomp_shape = Shape4D(new_ifm_shape[0:n])
            ifm2_decomp_shape = Shape4D(new_ifm2_shape[0:n])
            ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
            ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
            ifm_part_shape = Shape4D(new_ifm_shape[n:])
            ifm2_part_shape = Shape4D(new_ifm2_shape[n:])
            op.ifm_shapes.append(Shape4D([ifm_decomp_shape.elements()] + new_ifm_shape[n:]))
            op.ifm_shapes.append(Shape4D([ifm2_decomp_shape.elements()] + new_ifm2_shape[n:]))
        else:
            op.ifm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))

        op_list = []
        for height in range(ofm_decomp_shape.height):
            for width in range(ofm_decomp_shape.width):
                for depth in range(ofm_decomp_shape.depth):
                    ofm_offset = Shape4D(0, height, width, depth)
                    ofm_offset = Shape4D(ofm_offset.dot_prod(ofm_decomp_stride), 0, 0, 0)
                    ofm_cut = (ofm_offset, ofm_part_shape)

                    if binary:
                        ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
                        ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
                        ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
                        ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
                        ifm_offset = Shape4D(ifm_offset.dot_prod(ifm_decomp_stride), 0, 0, 0)
                        ifm_cut = (ifm_offset, ifm_part_shape)

                        ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
                        ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
                        ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
                        ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
                        ifm2_offset = Shape4D(ifm2_offset.dot_prod(ifm2_decomp_stride), 0, 0, 0)
                        ifm2_cut = (ifm2_offset, ifm2_part_shape)
                        op_list.append(create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut))
                    else:
                        op_list.append(create_elem_part_op(op, ofm_cut, None, ofm_cut))

        ofm.ops.remove(op)
        ifm.consumer_list.remove(op)
        if binary:
            ifm2.consumer_list.remove(op)

        return op_list
    else:
        op.ofm_shapes.append(Shape4D(new_ofm_shape))
        op.ifm_shapes.append(Shape4D(new_ifm_shape))
        op.ifm_shapes.append(Shape4D(new_ifm2_shape))

        return [op]


def decomp_elementwise(tens, arch, nng):
    """
    Decompose elementwise ops with Rank > 3 (H,W,C).
    Also decompose tensors whose size exceeds the NPU max size
    """
    tens_ops = tens.ops.copy()
    for op in tens_ops:
        if op.type.is_elementwise_op():
            decomp_list = decomp_dims_elementwise(op)
            for part_op in decomp_list:
                decompose_elem_tensors_hwc(part_op)
    return tens


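# Collapses the dimensions before/after axis so that axis becomes the W dimension of a rank-3 (H, W, C) shape,
# e.g. reshape_concat_shape([2, 3, 4, 5], rank=4, axis=1) -> [2, 3, 20]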
def reshape_concat_shape(shape, rank, axis):
    new_h = 1
    for i in range(axis):
        new_h *= shape[i]
    new_c = 1
    for i in range(axis + 1, rank):
        new_c *= shape[i]
    if axis == (rank - 1):
        new_shape = [new_h, shape[axis], 1]
    else:
        new_shape = [new_h, shape[axis], new_c]
    return new_shape


def reshape_concat(op):
    """
    Reshapes concat ops with Rank > 3 (H,W,C).
    """
    ofm = op.ofm
    rank = len(ofm.shape)
    axis = op.attrs["axis"]
    if axis < 0:
        axis += rank

    if rank > 3:
        # Reshape so that the axis to be concatenated is the W dimension
        # Reshape inputs
        for inp in op.inputs:
            new_shape = reshape_concat_shape(inp.shape, rank, axis)
            op.ifm_shapes.append(Shape4D(new_shape))
        # Reshape output
        new_shape = reshape_concat_shape(ofm.shape, rank, axis)
        op.ofm_shapes.append(Shape4D(new_shape))
        op.attrs["axis4D"] = 2
    else:
        for inp in op.inputs:
            op.ifm_shapes.append(Shape4D(inp.shape))
        op.ofm_shapes.append(Shape4D(ofm.shape))
        op.attrs["axis4D"] = axis + (4 - rank)


def decomp_rewrite_concat(tens, arch, nng):
    """
    Decompose concat ops with Rank > 3 (H,W,C).
    Rewrites concat to elementwise operations
    """
    if len(tens.ops) == 1 and tens.ops[0].type == Op.Concat:
        op = tens.ops[0]

        reshape_concat(op)
        rewrite_concat(op)

        op.ofm.ops.remove(op)
        for inp in op.inputs:
            inp.consumer_list.remove(op)

    return tens


def decomp_rewrite_pad(op, arch):
    """
    Decomposition of pad to elementwise operations:
    For each dimension that needs padding:
    - Create a new PAD operator for that dimension, with ifm/ofm reshaped
      so that the width dimension is the one to be padded (rank for each is 3)
    - Rewrite the new PAD operator so there is:
      - 1 Add operator for copying the data
      - 1 Add operator for each left/right border to be padded
    """
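    # For example, an NHWC padding of [[0, 0], [1, 1], [0, 0], [2, 2]] is decomposed into one PAD
    # over H and one PAD over C, each of which is then rewritten to add operations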
    # TODO several things would be possible to optimize
    # For instance there are cases when it should be possible to pad 2
    # dimensions at the same time.
    if op.type == Op.Pad:
        ofm_elements = shape_num_elements(op.ofm.shape)
        padding = op.inputs[1].values

        rank = len(op.ifm.shape)
        next_ifm = op.ifm
        next_ifm_shape = next_ifm.shape.copy()

        first_pad_rewrite_op = None
        ifm_quant = op.ifm.quantization.clone()

        for dim in range(padding.shape[0]):
            # Check if padding is to be applied in this dimension
            dim_pad = padding[dim]
            if not (dim_pad == 0).all():
                # Reshape so that the width dimension is the one to be padded
                new_ifm_shape = reshape_concat_shape(next_ifm_shape, rank, dim)
                new_pad_input = np.zeros((4, 2), dtype=np.int32)
                new_pad_input[2] = dim_pad

                pad_op = create_pad_nop(f"{op.name}_dim_{dim}")
                pad_op.add_input_tensor(next_ifm)
                new_pad_tens = op.inputs[1].clone("_dim_{dim}")

                name = op.inputs[1].name + f"_dim_{dim}"
                new_pad_tens = create_const_tensor(
                    name, list(new_pad_input.shape), DataType.int32, new_pad_input, np.int32
                )
                pad_op.add_input_tensor(new_pad_tens)

                new_ofm_shape = new_ifm_shape.copy()
                new_ofm_shape[-2] = new_ofm_shape[-2] + dim_pad.sum()
                next_ifm_shape[dim] = next_ifm_shape[dim] + dim_pad.sum()

                if Shape4D(new_ofm_shape).elements() == ofm_elements:
                    # Last one, use op.ofm
                    ofm = op.ofm
                else:
                    # Add a new ofm Tensor
                    ofm = Tensor(new_ofm_shape, op.ofm.dtype, f"{pad_op.name}_tens")
                    ofm.quantization = ifm_quant.clone()

                pad_op.set_output_tensor(ofm)
                pad_op.ifm_shapes.append(Shape4D(new_ifm_shape))
                pad_op.ofm_shapes.append(Shape4D(new_ofm_shape))
                DebugDatabase.add_optimised(op, pad_op)
                next_ifm = ofm

                # Rewrite the pad op
                converted_pad_op = convert_pad_in_width(pad_op)
                first_pad_rewrite_op = converted_pad_op
            else:
                # Change to Identity operation (will be removed)
                op.type = Op.Identity

        if first_pad_rewrite_op:
            assert op.ofm.shape == next_ifm_shape
            for inp in op.inputs:
                inp.consumer_list.remove(op)
            return first_pad_rewrite_op

    return op


def fixup_quantization(op, arch, nng):
    if op.ifm and op.ifm.quantization.zero_point is None:
        op.ifm.quantization.zero_point = 0
    if op.ifm2 and op.ifm2.quantization.zero_point is None:
        op.ifm2.quantization.zero_point = 0
    if not op.forced_output_quantization:
        if op.ofm and op.ofm.quantization and op.ofm.quantization.zero_point is None:
            op.ofm.quantization.zero_point = 0
    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
    assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    return op


def tosa_optimise_graph(nng, arch):

    # TODO the supported operator checking needs to be split into semantic and HW checks
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [supported_operator_check], rewrite_unsupported=False,
        )

    # Decomposing and rewrite of concat
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False
        )

    # Decomposing of pad
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [decomp_rewrite_pad])
        sg.refresh_after_modification()

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
        )

    # Removal of reshapes
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_ops])
        sg.refresh_after_modification()

    # Decomposing of elementwise
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [set_ifm_ofm_op_shapes], rewrite_unsupported=False,
        )

    # Removal of Transpose
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
        )

    # TODO, when and where to best handle calc_scaling_avgpool
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False,
        )

    # Rewrite Operators step
    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv, convert_table_to_lut]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    # Post-processing step 1
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
        )

    # Removal of SplitSliceRead (Slice); needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_splitsliceread])
        sg.refresh_after_modification()

    # Post-processing step 2
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)

    return nng