# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


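# Rewrite the output tensors of split-type ops (Split, SplitV, StridedSlice, UnpackReshaped; plain Unpack is
# handled separately by rewrite_unpack_output) into SplitSliceRead ops that carry an explicit read offset
# and read shape into the parent tensor.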
def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


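# A SplitSliceRead is not executed on its own: either its read offset/shape is moved onto the single NPU
# consumer, or a copying AvgPool is inserted to materialise the slice.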
def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


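# Example with SAME padding: a 3x3 kernel at stride 1 gives xpad = ypad = 2, split as 1 on the left/top and
# 1 on the right/bottom; the returned skirt is (top, left, ypad - top, xpad - left).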
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


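# Used for Conv2DBackpropInputSwitchedBias: the padding is calculated on the up-scaled input shape and the
# skirt equals the padding.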
def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


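# A bilinear resize of a 1x1 IFM simply broadcasts the single input value over the OFM, so it can be
# expressed as an elementwise Add between the IFM and a zero-filled tensor of the output shape.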
# Convert the op to an elementwise add
def convert_resizebilinear_1x1_to_add(op):
    op.type = Op.Add
    op.name = op.name + "_add"
    op.attrs["resizebilinear"] = True
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear to a number of 2x2 nearest-neighbor upscaling ops and one avgpool op with kernel size
# dependent on the upscaling factor. The avgpool kernel limit of 8x8 when padding is applied limits upscaling to 8x8.
def convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype
    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    if op.attrs["align_corners"]:
        shape_modifier = 1
        op.attrs["padding"] = Padding.VALID
    else:
        shape_modifier = 0
        op.attrs["padding"] = Padding.SAME
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())

    # Calculate how many times 2x2 upscaling needs to be performed
    upscale_factor = round(out_shape[1] / upscaled_shape[1])
    n = int(np.log2(upscale_factor))

    # Perform 2x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor 2x2 upscaling
        upscaled_shape = upscaled_shape * 2 - shape_modifier
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last 2x2 upscaling also applies avgpool with kernel size dependent on the upscaling factor and adds
    # padding to the right and bottom.
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]
    scaled_op.attrs["padding"] = Padding.EXPLICIT
    scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
    scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()

    return op


def fixup_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass nop resizebilinear
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resizebilinear_1x1_to_add(op)
        else:
            convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


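# FullyConnected expects a 4D IFM of shape [batch, 1, 1, ifm_depth]; collapse whatever leading shape the
# network uses into the batch dimension (ifm_depth is taken from the weight tensor).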
def rewrite_fully_connected_input(op, arch, nng):
    if op.type == Op.FullyConnected:
        n_in_elems = op.weights.shape[-2]
        elms = op.ifm.elements()
        batch_size = elms // n_in_elems
        assert batch_size * n_in_elems == elms

        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
    return op


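# A batched FullyConnected is laid out as a 2D plane so the batch runs over height/width instead:
# batch 4 -> 2x2, 8 -> 2x4, 16 -> 4x4, anything else -> 1xN, with the weights expanded to 4D (HWIO).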
def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


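# Concat is lowered to per-input AvgPool write ops, so a fused activation cannot stay on the concat itself;
# split it out into a separate activation op acting on an intermediate copy of the output.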
def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


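# First-layer stride-2 optimisation: when the IFM is shallow (depth <= 4) and has even width, two adjacent
# IFM columns are packed into the channel dimension and the kernel is reshaped to match, so the convolution
# can run with stride 1 over half the width.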
def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X    For X = -1 or X > 0
    |   \   /     This subgraph can be replaced with either
    |    Mul      an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


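# HardSwish, x * relu6(x + 3) / 6, is folded into a 256-entry LUT for int8/uint8; the fixed-point
# evaluation below mirrors the TensorFlow Lite Micro reference kernel so the results match the reference.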
def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


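# Generic 8-bit LUT builder: each of the 256 input codes is dequantised, passed through fn, and the result
# requantised to the output scale, clamped to the representable range.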
def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)


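# LeakyRelu as a LUT: inputs below the zero point are scaled by alpha, inputs above it by the identity
# scale, using the same quantised multiplier/shift arithmetic as the elementwise Mul scaling.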
def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


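# A no-op (e.g. a bypassed memory-only op) that still carries an activation can hand that activation to the
# producing NPU op, provided the producer is the only producer/consumer in between and has no activation of
# its own.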
def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                np.uint8,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
    return op


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op


def add_attrs_to_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        input_shape = op.ifm_shapes[0]
        upscaled_height = input_shape.height * 2
        upscaled_width = input_shape.width * 2
        out_shape = op.ofm_shapes[0]
        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
            # this means the output is supposed to be a x2 upscale,
            # so we need to do SAME padding
            op.attrs["padding"] = Padding.SAME
        elif (
            op.attrs["align_corners"]
            and out_shape.height == (upscaled_height - 1)
            and out_shape.width == (upscaled_width - 1)
        ):
            # here we can just run the avg pool without padding and
            # produce a (M * 2 - 1, N * 2 - 1) sized output
            op.attrs["padding"] = Padding.VALID
        else:
            return op
        op.ifm_resampling_mode = resampling_mode.NEAREST
        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    return op


def fixup_bias_tensors(op, arch, nng):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op


def fixup_asymmetric_weights(op, arch, nng):
    if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
        if op.ifm.dtype == DataType.int8:
            if not np.all(op.weights.quantization.zero_point == 0):
                print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
                op.weights.quantization.zero_point *= 0

    return op


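# MEAN is lowered to either an AvgPool (when IFM and OFM quantisation are identical) or a DepthwiseConv2D
# that sums over the reduced axes, with the 1/N scaling and zero-point corrections applied through weight
# scale, bias, or extra Add/Mul ops depending on which reference behaviour has to be matched.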
def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
    if op.type == Op.Mean and op.run_on_npu:
        keep_dims = op.attrs.get("keep_dims", False)
        inp, axis = op.inputs
        shape = inp.shape
        ofm_shape = op.ofm.shape
        dims = len(shape)
        dims_ofm = len(ofm_shape)

        # Height and width axes have different index depending on dimensions
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
            if dims in (2, 3):
                if axis == 0:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
            else:
                if axis == 1:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
        else:  # multiple axes
            axis = sorted(axis.values)
            h, w = [shape[i] for i in axis]

        # Set necessary depthwise attributes
        op.attrs.update(
            {
                "padding": Padding.VALID,
                "stride_h": 1,
                "stride_w": 1,
                "strides": (1, 1, 1, 1),
                "depth_multiplier": 1,
                "channel_multiplier": 1,
                "dilation_h_factor": 1,
                "dilation_w_factor": 1,
                "dilation": (1, 1, 1, 1),
            }
        )
        # Change op type
        op.type = Op.DepthwiseConv2DBias
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()

        weight_scale, bias = 1, None
        ofmq, ifmq = op.ofm.quantization, inp.quantization
        # Set rounding mode, scaling and zero point based on which reference implementation to match
        if len(shape) == 4 and axis == [1, 2] and keep_dims:
            if inp.dtype == DataType.uint8:
                # This attribute means a different scaling calculation is used in order to match reference
                op.low_precision_scaling = True
                weight_scale = h * w
                # Set zero points to 0 as they will be adjusted for with bias term
                foq = ofmq.clone()
                foq.zero_point = 0
                fiq = ifmq.clone()
                fiq.zero_point = 0
                op.forced_input_quantization = fiq
                bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
                # If the bias term is outside uint8 range, we need an Add op to apply it.
                if bias_term < 0 or bias_term > 255:
                    intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                    # Bias term has higher bitness (i32) than input/output (u8).
                    # 16 bits is enough since the bias is added/subtracted from a u8 value,
                    # the bias can only effectively assume values in the range [-255, 255].
                    intermediate.dtype = DataType.int16
                    intermediate.quantization.zero_point = 0
                    add_op = Operation(Op.Add, op.name + "_bias")
                    add_op.forced_output_quantization = foq
                    add_op.add_input_tensor(intermediate)
                    quant = QuantizationParameters()
                    quant.zero_point = 0
                    bias_term_tens = create_const_tensor(
                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
                    )
                    add_op.add_input_tensor(bias_term_tens)
                    add_op.set_output_tensor(op.ofm)
                    add_op.set_ifm_ofm_shapes()
                    add_op.activation = op.activation
                    op.activation = None
                    op.set_output_tensor(intermediate)
                    op.set_ifm_ofm_shapes()
                # If not, we can just do it with the OFM zero point.
                else:
                    foq.zero_point = bias_term
                    op.forced_output_quantization = foq
            else:
                assert inp.dtype == DataType.int8
                # Use a depthwise to calculate the sum,
                # followed by a multiplication with 1/N to get the MEAN
                weight_scale = 1
                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                intermediate.dtype = DataType.int16
                mul_op = Operation(Op.Mul, op.name + "_mul")
                mul_op.add_input_tensor(intermediate)
                # Create scalar containing 1/N
                quant = QuantizationParameters()
                quant.zero_point = 0
                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
                # while rounding mode NATURAL would round this to -1.
                # This can only occur if N is even, and can be emulated by
                # multiplying with a number that is slightly smaller than 1/N.
                # It must be so small that other roundings are not affected;
                # the calculated value is based on worst case,
                # which is sum 256 * N (the maximum sum that can occur with int8)
                n = int(h * w)
                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
                quant.scale_f32 = 1 / (n - eps)
                scalar = create_const_tensor(
                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
                )
                mul_op.add_input_tensor(scalar)
                mul_op.set_output_tensor(op.ofm)
                mul_op.set_ifm_ofm_shapes()
                mul_op.rounding_mode = NpuRoundingMode.NATURAL
                mul_op.activation = op.activation
                op.activation = None
                op.set_output_tensor(intermediate)
                op.set_ifm_ofm_shapes()
        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        def extend_dims(dim, in_shape):
            if dim < 4:
                in_shape = [1] + in_shape
                if dim == 2:
                    in_shape += [1]
            return in_shape

        if dims < 4 or dims_ofm < 4:
            # Fix the ofm dimension when keep_dims is false
            # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
            if isinstance(axis, int) and dims_ofm + 1 == dims:
                ofm_shape.insert(axis, 1)
            elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
                for i in axis:
                    ofm_shape.insert(i, 1)
            shape = extend_dims(dims, shape)
            dims_ofm = len(ofm_shape)
            ofm_shape = extend_dims(dims_ofm, ofm_shape)
            op.set_ifm_ofm_shapes()

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        # Set weight shape to [H,W,C,B]
        weight_shape = [h, w, shape[3], shape[0]]

        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                value_dtype=np.uint8,
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add None bias tensor
        op.inputs.append(None)
        # Add bias tensor
        if bias:
            bias_shape = [shape[-1]]
            op.set_input_tensor(
                create_const_tensor(
                    "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
                ),
                2,
            )

    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


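# Main entry point for the TFLite graph optimisation: runs the pre-processing checks, the concat/split
# rewrites, memory-only op removal, the per-operator rewrite list and finally pad/activation handling over
# every subgraph.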
def tflite_optimise_graph(nng, arch):
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
        )

    # Removal of memory only operators
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        optimise_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        fixup_resizebilinear,
        fixup_bias_tensors,
        fixup_asymmetric_weights,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    return nng