1# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
18# to do the traversal of the graph.
19import math
20import uuid
21
22import numpy as np
23
24from . import fp_math
25from . import rewrite_graph
26from . import scaling
27from .api import NpuRoundingMode
28from .data_type import DataType
29from .debug_database import DebugDatabase
30from .errors import UnsupportedFeatureError
31from .ethos_u55_regs.ethos_u55_regs import resampling_mode
32from .graph_optimiser_util import bypass_memory_only_ops
33from .graph_optimiser_util import calc_explicit_padding
34from .graph_optimiser_util import convert_depthwise_to_conv
35from .graph_optimiser_util import convert_to_lut
36from .graph_optimiser_util import fix_sg_input_output
37from .graph_optimiser_util import memory_only_ops
38from .graph_optimiser_util import move_splitsliceread_to_consumer
39from .graph_optimiser_util import needed_total_padding
40from .graph_optimiser_util import set_ifm_ofm_op_shapes
41from .graph_optimiser_util import set_tensor_equivalence
42from .numeric_util import clamp_sigmoid
43from .numeric_util import round_away_zero
44from .operation import create_activation_function
45from .operation import ExplicitScaling
46from .operation import NpuBlockType
47from .operation import Op
48from .operation import Operation
49from .operation import Padding
50from .operation_util import create_avgpool_nop
51from .operation_util import get_pad_values_from_input
52from .shape4d import Shape4D
53from .softmax import SoftMax
54from .tensor import check_quantized_tens_scaling_equal
55from .tensor import create_const_tensor
56from .tensor import create_equivalence_id
57from .tensor import QuantizationParameters
58from .tensor import Tensor
59from .tensor import TensorPurpose
60from .tflite_mapping import optype_to_builtintype
61
62passthrough_nodes = (Op.Identity,)
63
64
65def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
66 """Creates an average pool for the given concat op/input feature map"""
67 ofm = concat_op.ofm
68 avgpool_op = create_avgpool_nop(name)
69 avgpool_op.inputs = [ifm]
70 avgpool_op.outputs = [ofm]
71
72 avgpool_op.write_offset = write_offset
73 avgpool_op.write_shape = ifm_shape
74 ofm.ops.append(avgpool_op)
75 DebugDatabase.add_optimised(concat_op, avgpool_op)
76 avgpool_op.ifm_shapes.append(ifm_shape)
77 avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
78 avgpool_op.memory_function = Op.ConcatSliceWrite
79 return avgpool_op
80
81
82def remove_passthrough_tensor(tens, arch, nng):
83 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
84 assert len(tens.ops[0].inputs) == 1
85 tens = tens.ops[0].inputs[0]
86 return tens
87
88
89def rewrite_concat_ops(op, arch):
90 if not op.run_on_npu or not op.type.is_concat_op():
91 return
92
93 axis_4D = 0
94 ofm = op.ofm
95 ofm.ops = []
96 offset = 0
97
98 unfuse_activation_function(op)
99
100 if op.type == Op.Pack:
101 # Pack is also referred to as Stack
102 axis = int(op.attrs["axis"])
103 if axis < 0: # Convert to positive axis
104 axis = len(op.inputs[0].shape) + 1 + axis
105
106 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
107
108 axis_4D = axis + (4 - len(desired_shape))
109
110 for idx, inp in enumerate(op.inputs):
111 op.ifm_shapes[idx] = Shape4D(desired_shape)
112 op.type = Op.PackReshaped
113
114 inputs, axis = op.get_concat_inputs_axis()
115 for idx, inp in enumerate(inputs):
116 if op.type != Op.PackReshaped:
117 op.ifm_shapes[idx] = Shape4D(inp.shape)
118 if axis >= 0:
119 axis_4D = axis + (4 - len(inp.shape))
120 else:
121 axis_4D = axis
122 write_offset = [0, 0, 0, 0]
123 write_offset[axis_4D] = offset
124 concat_end = offset + op.ifm_shapes[idx][axis_4D]
125 create_avg_pool_for_concat(
126 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
127 )
128 offset = concat_end
129 assert ofm.shape[axis] == offset
130
131 return op
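    # Illustrative example (not from the original source): concatenating two 1x4x4x8 feature maps
    # along the depth axis gives an OFM of 1x4x4x16; each input is rewritten above to a pass-through
    # AvgPool with write_shape equal to its own shape and write_offsets of (0, 0, 0, 0) and
    # (0, 0, 0, 8), so each input is written directly into the correct slice of the OFM.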
132
133
134def rewrite_split_ops(tens, arch, nng):
135
136 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
137 split_op = tens.ops[0]
138
139 # Not supported, so leave it to run on the CPU
140 if not split_op.run_on_npu:
141 return tens
142
143 inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
144
145 tens.ops = []
146 new_op = Operation(Op.SplitSliceRead, split_op.name)
147 new_op.inputs = [inp]
148 ofm_shape_idx = 0
149 if None in (offset_end, offset_start):
150 read_shape = None
151 else:
152 # the read shape is relative to each start offset
153 read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]
154
155 # For Split the offset cannot be extracted from the tensor so it has to
156 # be calculated from the index of the output tensor
157 if axis is not None:
158 # Get the start and end of the split
159 offset_start = [0] * 4
160 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
161 for idx, out in enumerate(outputs):
162 if axis_4D_list is not None:
163 axis_4D = axis_4D_list[idx]
164 else:
165 split_op.ofm_shapes[idx] = Shape4D(out.shape)
166 if axis >= 0:
167 axis_4D = axis + (4 - len(out.shape))
168 else:
169 axis_4D = axis
170
171 if out == tens:
172 ofm_shape_idx = idx
173 read_shape = split_op.ofm_shapes[idx]
174 break
175
176 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
177
178 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
179 new_op.read_shapes[0] = read_shape
180 new_op.run_on_npu = True
181 new_op.set_output_tensor(tens)
182 new_op.ifm_shapes.append(Shape4D(inp.shape))
183 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
184 DebugDatabase.add_optimised(split_op, new_op)
185
186 return tens
187
188
189def remove_SplitSliceRead(op, arch):
190
191 if op.type == Op.SplitSliceRead:
192 # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
193 if (
194 len(op.ofm.consumer_list) == 1
195 and op.ofm.consumer_list[0] is not None
196 and op.ofm.consumer_list[0].run_on_npu
197 and op.ofm.consumer_list[0].type not in memory_only_ops
198 and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
199 ):
200 # SplitSliceRead can be performed by tensor consumer
201 cons_op = op.ofm.consumer_list[0]
202 move_splitsliceread_to_consumer(op, cons_op)
203 else:
204 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
205 avgpool_op.add_input_tensor(op.ifm)
206 avgpool_op.outputs = [op.ofm]
207 op.ofm.ops.remove(op)
208 op.ofm.ops.append(avgpool_op)
209 avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
210 avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
211 avgpool_op.read_offsets[0] = op.read_offsets[0]
212 avgpool_op.read_shapes[0] = op.read_shapes[0]
213
214 op.ifm.consumer_list.remove(op)
215 DebugDatabase.add_optimised(op, avgpool_op)
216
217
218def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
219 k_w, k_h = kernel.dilated_wh()
220 s_x, s_y = kernel.stride
221 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
222 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
223 if padding_type == Padding.SAME:
224 left_pad = (xpad + 0) // 2
225 right_pad = (xpad + 1) // 2
226 top_pad = (ypad + 0) // 2
227 bottom_pad = (ypad + 1) // 2
228 elif padding_type == Padding.VALID:
229 left_pad = 0
230 right_pad = 0
231 top_pad = 0
232 bottom_pad = 0
233 elif padding_type == Padding.EXPLICIT:
234 # Padding is specified in a PAD operator which has been bypassed.
235 top, left, bottom, right = explicit_padding
236 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
237 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
238 else:
239 raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
240 padding = (top_pad, left_pad, bottom_pad, right_pad)
241 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
242 return padding, skirt
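# Worked example (illustrative, assuming the usual SAME-padding arithmetic): a 3x3 kernel with
# stride 1 needs a total padding of 2 in each direction, so Padding.SAME above yields
# padding = (top, left, bottom, right) = (1, 1, 1, 1) and skirt = (1, 1, 1, 1).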
243
244
245def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
246 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
247 if padding_type == Padding.SAME:
248 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
249 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
250 right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
251 bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
252 left_pad = max(kernel_width - 1 - right_pad, 0)
253 top_pad = max(kernel_height - 1 - bottom_pad, 0)
254 elif padding_type == Padding.VALID:
255 right_pad = max(kernel_width - 2, 0)
256 bottom_pad = max(kernel_height - 2, 0)
257 left_pad = kernel_width - 1
258 top_pad = kernel_height - 1
259 else:
260 raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
261 padding = (top_pad, left_pad, bottom_pad, right_pad)
262 skirt = padding
263 return padding, skirt
264
265
266def fixup_conv2d_backprop(op, arch, nng):
267 if op.type == Op.Conv2DBackpropInput:
268 # flip the inputs
269 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
270 op.type = Op.Conv2DBackpropInputSwitchedBias
271 op.ifm_resampling_mode = resampling_mode.TRANSPOSE
272
273 # Update strides
274 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
275
276 return op
277
278
279# Convert the op to an elementwise add
280def convert_resizebilinear_1x1_to_add(op):
281 op.type = Op.Add
282 op.name = op.name + "_add"
283 op.attrs["resizebilinear"] = True
284 # Create an input tensor filled with zeros
285 shape = op.ofm_shapes[0].as_list()
286 tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
287 tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
288 tens.quantization = QuantizationParameters(0.0, 255.0)
289 tens.quantization.scale_f32 = 1.0
290 tens.quantization.zero_point = 0
291 tens.consumer_list = [op]
292 tens_op = op.inputs[1].ops[0]
293 tens_op.set_output_tensor(tens)
294 # Set the add inputs
295 op.inputs[1] = op.inputs[0]
296 op.inputs[0] = tens
297 op.set_ifm_ofm_shapes()
298
299 return op
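# Note (illustrative): bilinear resizing of a 1x1 IFM simply broadcasts its single value to every
# output position, so an elementwise Add with an all-zero tensor of the output shape reproduces
# the result exactly; the "resizebilinear" attribute records where the Add came from.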
300
301
302# Convert ResizeBilinear to a number of 2x2 nearest neighbor upscaling ops and one avgpool op with a kernel size
303# dependent on the upscaling factor. The avgpool kernel limit of 8x8 when padding is applied limits upscaling to 8x8.
304def convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool(op):
305 pre_op = op
306 outputs = op.outputs
307 dtype = op.ifm.dtype
308 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
309 if op.attrs["align_corners"]:
310 shape_modifier = 1
311 op.attrs["padding"] = Padding.VALID
312 else:
313 shape_modifier = 0
314 op.attrs["padding"] = Padding.SAME
315 op.ifm_resampling_mode = resampling_mode.NEAREST
316
317 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
318 out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
319
320 # Calculate how many times 2x2 upscaling needs to be performed
321 # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
322 # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
323 upscale_factor = int(round(out_shape[1] / upscaled_shape[1]))
324 n = int(np.log2(upscale_factor))
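    # Illustrative example: resizing a 4x4 IFM to 32x32 (align_corners False) gives
    # upscale_factor = 8 and n = 3, i.e. two plain 2x2 nearest neighbor upscales below
    # followed by a final 2x2 upscale fused with an 8x8 avgpool and explicit padding.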
325
326 # Perform 2x2 upscaling n-1 times
327 scaled_op = pre_op
328 for count in range(n - 1):
329 if count > 0:
330 scaled_op = op.clone(f"_{count}")
331 scaled_op.inputs[0] = pre_op.outputs[0]
332
333 # Nearest neighbor 2x2 upscaling
334 upscaled_shape = upscaled_shape * 2 - shape_modifier
335 shape = op.ofm_shapes[0].as_list()
336 shape[1:3] = upscaled_shape
337 out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
338 out_tens.quantization = op.outputs[0].quantization.clone()
339 scaled_op.set_output_tensor(out_tens)
340 pre_op = scaled_op
341
342 scaled_op.set_ifm_ofm_shapes()
343
344 # Last 2x2 upscaling also applies avgpool with kernel size dependent on the upscaling factor and adds
345 # padding to the right and bottom.
346 if n > 1:
347 scaled_op = op.clone(f"_{n-1}")
348 scaled_op.inputs[0] = pre_op.outputs[0]
349 scaled_op.attrs["padding"] = Padding.EXPLICIT
350 scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
351 scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
352 scaled_op.outputs = outputs
353 scaled_op.outputs[0].ops = [scaled_op]
354 scaled_op.set_ifm_ofm_shapes()
355
356 return op
357
358
359def fixup_resizebilinear(op, arch, nng):
360 if op.type == Op.ResizeBilinear and op.run_on_npu:
361 if op.ifm_shapes[0] == op.ofm_shapes[0]:
362 # Bypass nop resizebilinear
363 op.inputs = op.inputs[:1]
364 op.type = Op.Identity
365 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
366 convert_resizebilinear_1x1_to_add(op)
367 else:
368 convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool(op)
369
370 return op
371
372
373def convert_nop_split_to_identity(op, arch, nng):
374 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
375 # the list comprehension should return a list with a single tensor
376 # if it does not, remove_passthrough_tensor will fail appropriately
377 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
378 op.type = Op.Identity
379 return op
380
381
382def rewrite_fully_connected_input(op: Operation, arch, nng):
383
384 if op.type == Op.FullyConnected:
385 new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
386 assert new_shape is not None, "Tensor can not be reshaped to 2D"
387 op.ifm_shapes[0] = new_shape
388 return op
389
390
391def convert_batched_fc_shape(op, arch, nng):
392 if op.type == Op.FullyConnected:
393 # Check if the first dimension indicates batching
394 if op.ifm_shapes[0].batch > 1:
395 batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
396 n = op.ifm_shapes[0].batch
397 h, w = batching_split.get(n, (1, n))
398 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
399
400 # Reshape Weights to be 4D. IO becomes HWIO
401 weight_tensor = op.inputs[1]
402 weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
403 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
404
405 n = op.ofm_shapes[0].batch
406 h, w = batching_split.get(n, (1, n))
407 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
408 return op
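# Illustrative example: a FullyConnected op with IFM shape [8, 1024] (batch 8) is reshaped above to
# Shape4D [1, 2, 4, 1024] via batching_split, and its IO weights gain two leading unit dimensions
# (HWIO), so the batch is handled as a 2x4 spatial layout instead of a batch dimension.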
409
410
411def unfuse_activation_function(op):
412 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
413 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
414 op.activation = None
415 out_tens = op.outputs[0]
416 intermediate_tens = out_tens.clone("_act_intermediate")
417 act_op.set_output_tensor(out_tens)
418 act_op.add_input_tensor(intermediate_tens)
419 op.set_output_tensor(intermediate_tens)
420 act_op.set_ifm_ofm_shapes()
421
422
423def rewrite_stridedslice_output(op, arch, nng):
424 if not op.run_on_npu or op.type != Op.StridedSlice:
425 return op
426
427 new_axis_mask = op.attrs["new_axis_mask"]
428 shrink_axis_mask = op.attrs["shrink_axis_mask"]
429
430 if shrink_axis_mask == 0 and new_axis_mask == 0:
431 return op
432
433 axis_4D = [0] * len(op.outputs)
434 for idx, out_tens in enumerate(op.outputs):
435 output_shape = list(out_tens.shape)
436
437 if shrink_axis_mask != 0:
438 n = 0
439 axis = 0
440 while shrink_axis_mask:
441 prev_mask = shrink_axis_mask
442 n += 1
443 shrink_axis_mask &= shrink_axis_mask - 1
444 axis = int(math.log2(prev_mask - shrink_axis_mask))
445 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
446
447 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
448 op.attrs["shrink_axis_mask"] = 0
449 if axis >= 0:
450 axis_4D[idx] = axis + (4 - len(output_shape))
451 else:
452 axis_4D[idx] = axis
453 op.ofm_shapes[idx] = Shape4D(output_shape)
454
455 elif new_axis_mask != 0:
456 n = 0
457 axis = 0
458 while new_axis_mask:
459 prev_mask = new_axis_mask
460 n += 1
461 new_axis_mask &= new_axis_mask - 1
462 axis = int(math.log2(prev_mask - new_axis_mask))
463 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
464 new_axis_mask >>= 1
465
466 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
467 op.attrs["new_axis_mask"] = 0
468 if axis >= 0:
469 axis_4D[idx] = axis + (4 - len(output_shape))
470 else:
471 axis_4D[idx] = axis
472 op.ofm_shapes[idx] = Shape4D(output_shape)
473
474 op.attrs["split_axis_4D"] = axis_4D
475 return op
476
477
478def rewrite_unpack_output(op, arch, nng):
479 tens = op.outputs[0]
480 if op.run_on_npu and op.type == Op.Unpack:
481 # Unpack is also referred to as Unstack
482 axis = int(op.attrs["axis"])
483 if axis < 0: # Convert to positive axis
484 axis = len(op.inputs[0].shape) + 1 + axis
485 op.type = Op.UnpackReshaped
486 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
487
488 axis_4D = axis + (4 - len(desired_output_shape))
489 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
490
491 for idx, out_tens in enumerate(op.outputs):
492 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
493 return op
494
495
496def add_padding_fields(op, arch, nng):
497 if op.run_on_npu:
498 if "padding" in op.attrs:
499 input_shape = op.ifm_shapes[0]
500 output_shape = op.ofm_shapes[0]
501 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
502 kernel_size = op.inputs[1].shape[:2]
503 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
504 kernel_size = op.attrs["ksize"][1:3]
505 else:
506 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
507
508 if op.type == Op.Conv2DBackpropInputSwitchedBias:
509 upscaling_factor = output_shape.height // input_shape.height
510 padding, skirt = calc_upscaled_padding_and_skirt(
511 op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
512 )
513 else:
514 padding, skirt = calc_padding_and_skirt(
515 op.attrs["padding"],
516 op.kernel,
517 input_shape,
518 op.attrs.get("explicit_padding"),
519 )
520
521 op.attrs["explicit_padding"] = padding
522 op.attrs["skirt"] = skirt
523
524 return op
525
526
527def reorder_depthwise_weights(op, arch, nng):
528 if op.type.is_depthwise_conv2d_op():
529 weight_tensor = op.inputs[1]
530 weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
531 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
532 weight_tensor.weight_transpose_depthwise = True
533
534 return op
535
536
537def optimise_strided_conv(op, arch, nng):
538 if op.type != Op.Conv2DBias or op.op_index != 0:
539 return op
540 stride_x, stride_y = op.get_kernel_stride()
541 weight_tensor = op.weights
542 ifm_shape = op.ifm_shapes[0]
543
544 if (
545 stride_x == 2
546 and ifm_shape.depth <= 4
547 and ifm_shape.width % 2 == 0
548 and weight_tensor is not None
549 and weight_tensor.shape[1] >= 2
550 ):
551 k_w, _ = op.get_kernel_size()
552 curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
553 optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
554 if curr_padding_x != optimised_padding_x:
555 # Horizontal padding would become different after optimisation; this would not work
556 return op
557 # IFM
558 op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
559
560 # Weights
561 weight_shape = weight_tensor.shape
562 if weight_shape[1] % 2 != 0:
563 weight_shape[1] = weight_shape[1] + 1
564 padded_array = np.zeros(weight_shape)
565 for i in range(weight_shape[0]):
566 padded_array[i] = np.vstack(
567 [
568 weight_tensor.values[i],
569 np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
570 ]
571 )
572 weight_tensor.values = padded_array
573 weight_shape[1] //= 2
574 weight_shape[2] *= 2
575 weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
576 weight_tensor.set_all_shapes(weight_shape)
577 # If multiple copies of the weights are used, we could avoid
578 # them having the same address by changing the value_id
579 weight_tensor.value_id = uuid.uuid4()
580
581 # Strides
582 stride_x = 1
583 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
584
585 return op
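# Illustrative example (assuming the HWIO weight layout used elsewhere in this file): a 3x3
# stride-2 convolution over a 1xHxWx3 IFM with even W is rewritten above so that the IFM is read
# as 1xHx(W/2)x6, the kernel width is zero-point padded from 3 to 4 and folded into a 3x2x6xN
# weight tensor, and stride_w becomes 1, preserving the arithmetic of the original stride-2 op.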
586
587
588def convert_conv_to_fc(op, arch, nng):
589 # Conv 1x1 can be equivalent to Fully Connected.
590 # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
591 # caching/double buffering for the weights.
592 # (Weights don't need to be reloaded for convs when IFM H and W are 1)
593 if op.type == Op.Conv2DBias:
594 h = op.ifm_shapes[0].height
595 w = op.ifm_shapes[0].width
596 kh, kw, _, _ = op.inputs[1].shape
597 if h == 1 and w == 1 and kh == 1 and kw == 1:
598 # Overwrite this op as a Fully Connected Op
599 op.name += "_fc"
600 op.type = Op.FullyConnected
601 op.attrs = {
602 "weights_format": 0,
603 }
604 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
605 weight_tensor = op.inputs[1]
606 weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
607 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
608
609 DebugDatabase.add_optimised(op, op)
610 return op
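# Note (illustrative): with a 1x1 IFM and a 1x1 kernel, each output channel is just a dot product
# over the input channels plus a bias, which is exactly a FullyConnected layer; squeezing the
# HWIO weights down to IO above makes that equivalence explicit.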
611
612
613def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
614 if op.run_on_npu and op.type.is_relu_op():
615 ifm = op.inputs[0]
616 ofm = op.outputs[0]
617 # Relu with differing IFM and OFM scaling cannot be fused with another primary op
618 # and requires its own to be inserted
619 if not check_quantized_tens_scaling_equal(ifm, ofm):
620 # Override this op with its own primary op (avgpool)
621 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
622 # And fuse the original activation function to it
623 relu_fused_op.activation = create_activation_function(op.type)
624 # Add explicit rescaling
625 rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
626 multiplier, shift = scaling.quantise_scale(rescale)
627 relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
628 # Tidy up and assign the ifm and ofm to the new op
629 ifm.consumer_list.remove(op)
630
631 relu_fused_op.add_input_tensor(ifm)
632 relu_fused_op.set_output_tensor(ofm)
633 relu_fused_op.set_ifm_ofm_shapes()
634 op = relu_fused_op
635 return op
636
637
638def convert_softmax(op, arch, nng):
639 if op.type == Op.Softmax and op.run_on_npu:
640 softmax = SoftMax(op)
641 op = softmax.get_graph()
642 return op
643
644
645def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
646 r"""Whenever there is a subgraph with this topology:
647
648     Input    X    For X = -1 or X > 0
649       |   \  /    This subgraph can be replaced with either
650       |    Mul    an Abs (if X = -1) or a LeakyReLU (if X > 0)
651       |   /
652       Max
653     """
654
655 if op.type == Op.Maximum:
656 # finds the Mul input(s) to the Max
657 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
658 if len(muls) == 1:
659 mul = muls[0].ops[0]
660 elif len(muls) == 2:
661 # In the case both inputs are Muls, find the one with the same input as the Max
662 mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
663 else:
664 # No Mul inputs
665 return op
666
667 # make sure the Mul doesn't have any other consumers
668 mul_ofm = mul.outputs[0]
669 if len(mul_ofm.consumers()) != 1:
670 return op
671 # make sure the Mul doesn't have a fused activation function
672 if mul.activation:
673 return op
674 ifm, ofm = op.get_ifm_ofm()
675 if ifm is None or ofm is None:
676 return op
677
678 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
679 return op
680 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
681 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
682 return op
683
684 # finds the branched input that goes to both the Max and the Mul
685 shared = set(op.inputs) & set(mul.inputs)
686 if len(shared) == 1:
687 shared_in = shared.pop()
688 # find the constant scalar input to the Mul
689 const_tens = (set(mul.inputs) - {shared_in}).pop()
690 # check that it is a scalar
691 if const_tens.shape != []:
692 return op
693 const = const_tens.ops[0]
694 # check that it is a constant
695 if const.type != Op.Const:
696 return op
697 # Remove the Mul from the shared input's consumers
698 shared_in.consumer_list.remove(mul)
699 else:
700 return op
701
702 val = const.outputs[0].values
703 if val >= 0:
704 new_op = Op.LeakyRelu
705 op.attrs["alpha"] = val
706 # to produce bit-exact results, the alpha value alone is not enough;
707 # save additional scaling info in attr "alpha_scale", to be used as input
708 # to the LUT construction
709 alpha_scalar = const_tens.values - const_tens.quantization.zero_point
710 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
711 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
712 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
713 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
714 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
715 elif val == -1:
716 new_op = Op.Abs
717 else:
718 return op
719
720 op.type = new_op
721 op.name = op.name.replace("Maximum", new_op.name)
722 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
723 op.inputs = [shared_in]
724 op.set_ifm_ofm_shapes()
725
726 # Record optimisation in debug database
727 DebugDatabase.add_optimised(op, op)
728
729 return op
730
731
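# Reference formula (standard definition): HardSwish(x) = x * relu6(x + 3) / 6.
# The conversion below evaluates this on the dequantised value of every possible input code and
# stores the requantised results in a 256-entry LUT.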
732def convert_hardswish_to_lut(op, arch, nng):
733 if op.type == Op.HardSwish:
734 ifm, ofm = op.get_ifm_ofm()
735 # Generate the LUT
736 ifm_scale = np.double(ifm.quantization.scale_f32)
737 ofm_scale = np.double(ofm.quantization.scale_f32)
738 zp_in = ifm.quantization.zero_point
739 zp_out = ofm.quantization.zero_point
740 ifm_scale_hires = (1 / 128) * ifm_scale
741 relu_multiplier = np.double(3 / 32768)
742 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
743 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
744 # Use 16bit scale
745 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
746 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
747
748 values = []
749 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
750 quantized_min = min(ix)
751 quantized_max = max(ix)
752 for x in ix:
753 input_value = x - zp_in
754 input_value_hires = input_value * 128
755 # Compute the input value on essentially the output scale, not shifted yet
756 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
757 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
758 relu_value = np.int16(input_value_hires)
759 if relu_shift < 31:
760 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
761
762 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
763
764 if relu_shift < 31:
765 relu_value = fp_math.shift_left16(relu_value, 1)
766
767 if relu_shift > 31:
768 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
769
770 # Rescale the value into a 16-bit fixed-point relu_value in [-1, 1]
771 # Now convert that to a 16-bit fixed-point value in [0, 1]
772 relu_value = (relu_value + (1 << 15)) >> 1
773 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
774 shift = 31 - out_shift
775 shift = -shift if shift < 0 else 0
776 # Finally apply the output shift
777 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
778 lut_result = min(quantized_max, max(quantized_min, lut_result))
779 values.append(lut_result)
780 return convert_to_lut(op, values, "hardswish")
781 return op
782
783
784def convert_lrelu_to_mul_max(op, arch):
785 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
786 # (the opposite of convert_mul_max_to_abs_or_lrelu)
787 ifm, ofm = op.get_ifm_ofm()
788 if ifm is None or ofm is None:
789 return op
790
791 # Add multiplication with alpha
792 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
793 mul_alpha.add_input_tensor(ifm)
794 # Create const tensor containing alpha as scalar
795 alpha = np.float32(op.attrs["alpha"])
796 quantization = ifm.quantization.clone()
797 quantization.min = 0
798 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
799 quantization.zero_point = 0
800 if np.isinf(1 / alpha):
801 # Handling of alpha near zero
802 quantization.scale_f32 = np.float32(1)
803 scalar = 0
804 else:
805 quantization.scale_f32 = alpha
806 scalar = alpha
807 alpha_tens = create_const_tensor(
808 op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
809 )
810 alpha_tens.values = np.array([1])
811 mul_alpha.add_input_tensor(alpha_tens)
812 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
813 mul_alpha.set_output_tensor(fm_alpha)
814 mul_alpha.set_ifm_ofm_shapes()
815 DebugDatabase.add_optimised(op, mul_alpha)
816
817 if check_quantized_tens_scaling_equal(ifm, ofm):
818 # No identity multiplication is needed
819 fm_id = ifm
820 else:
821 # Add multiplication with identity
822 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
823 mul_identity.add_input_tensor(ifm)
824 # Create const tensor containing identity as scalar
825 quantization = ifm.quantization.clone()
826 quantization.min = 0
827 quantization.max = quantization.quant_max - quantization.quant_min
828 quantization.scale_f32 = np.float32(1)
829 quantization.zero_point = 0
830 identity_tens = create_const_tensor(
831 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
832 )
833 mul_identity.add_input_tensor(identity_tens)
834 # Make sure that fm_id is allocated to a different address than fm_alpha
835 fm_id = ofm.clone(op.name + "_id", set_unique=True)
836 mul_identity.set_output_tensor(fm_id)
837 mul_identity.set_ifm_ofm_shapes()
838 DebugDatabase.add_optimised(op, mul_identity)
839
840 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
841 op.type = Op.Maximum
842 op.name = op.name.replace("LeakyRelu", "Maximum")
843 op.inputs = []
844 ifm.consumer_list.remove(op)
845 op.add_input_tensor(fm_alpha)
846 op.add_input_tensor(fm_id)
847 op.set_ifm_ofm_shapes()
848
849 DebugDatabase.add_optimised(op, op)
850 return op
851
852
853def convert_to_lut8(op, fn, fn_name):
854 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
855 # fn is a function(real) -> real
856 ifm, ofm = op.get_ifm_ofm()
857 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
858 return op
859 # Generate the LUT
860 ifm_scale = np.double(ifm.quantization.scale_f32)
861 ofm_scale = np.double(ofm.quantization.scale_f32)
862 zp_in = ifm.quantization.zero_point
863 zp_out = ofm.quantization.zero_point
864 values = []
865 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
866 quantized_min = min(ix)
867 quantized_max = max(ix)
868 for x in ix:
869 x_real = ifm_scale * (x - zp_in)
870 y_real = fn(x_real)
871 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
872 lut_result = min(quantized_max, max(quantized_min, lut_result))
873 values.append(lut_result)
874 return convert_to_lut(op, values, fn_name)
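# Worked example (illustrative): for a uint8 Sigmoid with zp_in = 128, ifm_scale = 0.1,
# zp_out = 0 and ofm_scale = 1 / 256, the entry for x = 138 comes from
# x_real = 0.1 * (138 - 128) = 1.0 and y_real = sigmoid(1.0) ~= 0.731, giving
# round_away_zero(0.731 * 256) = 187 after clamping to [0, 255].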
875
876
877def convert_lrelu_to_lut(op, arch):
878 ifm, ofm = op.get_ifm_ofm()
879 # Generate the LUT
880 alpha = op.attrs["alpha"]
881 ifm_scale = np.double(ifm.quantization.scale_f32)
882 ofm_scale = np.double(ofm.quantization.scale_f32)
883 zp_in = ifm.quantization.zero_point
884 zp_out = ofm.quantization.zero_point
885 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
886 alpha_scalar = 1
887 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
888 if "alpha_scaling" in op.attrs:
889 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
890 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
891 values = []
892 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
893 quantized_min = min(ix)
894 quantized_max = max(ix)
895 for x in ix:
896 if x < zp_in:
897 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
898 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
899 )
900 else:
901 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
902 lut_result = min(quantized_max, max(quantized_min, lut_result))
903 values.append(lut_result)
904 return convert_to_lut(op, values, "lrelu")
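# In effect each LUT entry above approximates (illustrative summary)
#   zp_out + round(alpha * (ifm_scale / ofm_scale) * (x - zp_in))   for x <  zp_in
#   zp_out + round((ifm_scale / ofm_scale) * (x - zp_in))           for x >= zp_in
# using the fixed-point multiplier/shift pairs produced by elementwise_mul_scale.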
905
906
907def convert_lrelu(op, arch, nng):
908 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
909 if op.type != Op.LeakyRelu:
910 return op
911 ifm, ofm = op.get_ifm_ofm()
912 if ifm is None or ofm is None:
913 return op
914 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
915 # use LUT for int8/uint8
916 return convert_lrelu_to_lut(op, arch)
917 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
918 # use LeakyRelu unmodified for int16 with equal input/output scaling
919 return op
920 return convert_lrelu_to_mul_max(op, arch)
921
922
923def convert_tanh_sigmoid_to_lut(op, arch, nng):
924 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
925 if op.type == Op.Sigmoid:
926 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
927 elif op.type == Op.Tanh:
928 return convert_to_lut8(op, math.tanh, "tanh")
929 return op
930
931
932def remove_memory_only_ops(op, arch):
933 if op.run_on_npu and op.type in memory_only_ops:
934 bypass_memory_only_ops(op)
935
936
937def fuse_activation_function_with_prev(op, arch, nng):
938 # if op is a no-op: attempts to move the activation function to the preceding op
939 if not op.attrs.get("is_nop", False) or op.activation is None:
940 return op
941 ifm, ofm = op.get_ifm_ofm()
942 if ifm is None or ofm is None:
943 return op
944 # finds the input(s) to the operation
945 prev_op = ifm.ops[0]
946 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
947 fuse = (
948 prev_op.run_on_npu
949 and prev_op.type.npu_block_type != NpuBlockType.Default
950 and len(ifm.ops) == 1
951 and len(prev_op.outputs[0].consumers()) == 1
952 and prev_op.activation is None
953 )
954 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
955 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
956 # LUT currently only works correctly for elementwise ops
957 fuse = False
958 if not fuse:
959 return op
960 # Move the fused activation function + corresponding info to prev_op
961 prev_op.activation = op.activation
962 prev_op.forced_output_quantization = op.forced_output_quantization
963 if op.activation_lut is not None:
964 prev_op.set_activation_lut(op.activation_lut)
965 # Bypass op
966 prev_op.set_output_tensor(ofm)
967 DebugDatabase.add_optimised(op, prev_op)
968 return op
969
970
971def _leading_pad_ok(leading_pad, stride, kernel_size):
972 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
973 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
974 max_size = kernel_size // 2
975 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
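# Illustrative example: for a 7-wide kernel with stride 2, max_size is 3, so a leading pad of
# 3 (the maximum) or any multiple of the stride (e.g. 2) is accepted, while a leading pad of 1
# is rejected because hardware padding would start the kernel on the wrong IFM column.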
976
977
978def replace_pad_by_hw_pad(op: Operation, arch, nng):
979 """
980 Tries to completely remove a PAD operator by using hardware padding.
981 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
982 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
983 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
984 if both operations can be run on the NPU.
985 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
986 """
987 if (
988 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
989 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
990 and op.run_on_npu
991 and op.attrs["padding"] == Padding.VALID
992 ):
993 pad_op = op.ifm.ops[0]
994 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
995 return op
996 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
997 return op
998 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
999 k = op.kernel
1000 k_w, k_h = k.dilated_wh()
1001
1002 # Check if the PAD operator can be replaced by hardware padding
1003 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1004 # Too much padding, it would require hardware padding to actually insert zeros
1005 return op
1006 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1007 return op
1008
1009 if op.type.is_avgpool_op():
1010 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1011 for pad, k_size in (
1012 (left, k_w),
1013 (right, k_w),
1014 (top, k_h),
1015 (bottom, k_h),
1016 ):
1017 if pad not in (0, k_size // 2):
1018 return op
1019 # Average pool is converted to depthwise, because NPU average pool + same padding
1020 # has a special implementation that is different from PAD followed by average pool with
1021 # valid padding.
1022 k_w, k_h = op.kernel.width, op.kernel.height
1023 ifm = op.ifm
1024 # Remember other inputs
1025 other_inputs = op.inputs[1:]
1026 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1027 quantization = QuantizationParameters(0.0, 255.0)
1028 quantization.scale_f32 = 1.0 / (k_w * k_h)
1029 quantization.zero_point = 0
1030 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1031 weights = np.full(shape, 1)
1032
1033 weight_tens = create_const_tensor(
1034 op.name + "_weights",
1035 shape,
1036 op.ifm.dtype,
1037 weights,
1038 np.uint8,
1039 purpose=TensorPurpose.Weights,
1040 quantization=quantization,
1041 )
1042 weight_tens.values = weights
1043 op.type = Op.DepthwiseConv2DBias
1044 op.inputs = []
1045 op.add_input_tensor(ifm)
1046 op.add_input_tensor(weight_tens)
1047 # Add bias tensor, all biases set to 0
1048 op.inputs.append(None)
1049 fixup_bias_tensors(op, arch, nng)
1050 # Add other inputs
1051 op.inputs.extend(other_inputs)
1052 op.rounding_mode = NpuRoundingMode.NATURAL
1053
1054 # Bypass the PAD operator
1055 op.set_input_tensor(pad_op.ifm, 0)
1056 # Adjust the padding attributes of the convolution operator
1057 op.attrs["padding"] = Padding.EXPLICIT
1058 op.attrs["explicit_padding"] = (top, left, bottom, right)
1059 op.set_ifm_ofm_shapes()
1060 return op
1061
1062
1063def convert_pad(op: Operation, arch, nng):
1064 """
1065 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1066 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1067 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1068 """
1069 if op.type != Op.Pad or not op.run_on_npu:
1070 return op
1071 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1072
1073 ifm = op.ifm
1074 assert ifm is not None
1075 ifm_shape = op.ifm_shapes[0]
1076 ofm = op.ofm
1077 assert ofm is not None
1078 ofm.ops = []
1079 ofm_shape = op.ofm_shapes[0]
1080
1081 # Average pool op that copies IFM to the right place inside the OFM
1082 shp0 = Shape4D(0, 0, 0, 0)
1083 shp_top = shp0.with_height(top)
1084 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1085 avgpool_op.activation = op.activation
1086 quant = ofm.quantization
1087 pad_value = quant.zero_point
1088 # Add operations that fill the borders of the OFM
1089 if top > 0:
1090 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1091 zero_tens = create_const_tensor(
1092 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1093 )
1094 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1095 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1096 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1097 if bottom > 0:
1098 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1099 zero_tens = create_const_tensor(
1100 op.name + "_bottom",
1101 shape.as_list(),
1102 ofm.dtype,
1103 shape.elements() * [pad_value],
1104 np.uint8,
1105 quantization=quant,
1106 )
1107 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1108 create_avg_pool_for_concat(
1109 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1110 )
1111 if left > 0:
1112 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1113 zero_tens = create_const_tensor(
1114 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1115 )
1116 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1117 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1118 if right > 0:
1119 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1120 zero_tens = create_const_tensor(
1121 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1122 )
1123 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1124 create_avg_pool_for_concat(
1125 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1126 )
1127
1128 op.type = Op.ConcatTFLite
1129 return avgpool_op
1130
1131
1132def add_attrs_to_resizebilinear(op, arch, nng):
1133 if op.type == Op.ResizeBilinear and op.run_on_npu:
1134 input_shape = op.ifm_shapes[0]
1135 upscaled_height = input_shape.height * 2
1136 upscaled_width = input_shape.width * 2
1137 out_shape = op.ofm_shapes[0]
1138 if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
1139 # this means the output is supposed to be a x2 upscale,
1140 # so we need to do SAME padding
1141 op.attrs["padding"] = Padding.SAME
1142 elif (
1143 op.attrs["align_corners"]
1144 and out_shape.height == (upscaled_height - 1)
1145 and out_shape.width == (upscaled_width - 1)
1146 ):
1147 # here we can just run the avg pool without padding and
1148 # produce a (M * 2 - 1, N * 2 - 1) sized output
1149 op.attrs["padding"] = Padding.VALID
1150 else:
1151 return op
1152 op.ifm_resampling_mode = resampling_mode.NEAREST
1153 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
1154 return op
1155
1156
1157def fixup_bias_tensors(op, arch, nng):
1158 if op.type.needs_bias() and op.bias is None:
1159 # Op has no bias, add bias tensor filled with zeros
1160 nr_biases = op.inputs[1].shape[-1]
1161 bias_values = [0] * nr_biases
1162 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
1163 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1164
1165 return op
1166
1167
1168def fixup_asymmetric_weights(op, arch, nng):
1169 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1170 if op.ifm.dtype == DataType.int8:
1171 if not np.all(op.weights.quantization.zero_point == 0):
1172 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1173 op.weights.quantization.zero_point *= 0
1174
1175 return op
1176
1177
1178def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1179 if op.type == Op.Mean and op.run_on_npu:
1180 keep_dims = op.attrs.get("keep_dims", False)
1181 inp, axis = op.inputs
1182 shape = inp.shape
1183 ofm_shape = op.ofm.shape
1184 dims = len(shape)
1185 dims_ofm = len(ofm_shape)
1186
1187 # Height and width axes have different index depending on dimensions
1188 if axis.shape == [] or axis.shape[0] == 1: # single axis
1189 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1190 if dims in (2, 3):
1191 if axis == 0:
1192 h, w = shape[axis], 1
1193 else:
1194 h, w = 1, shape[axis]
1195 else:
1196 if axis == 1:
1197 h, w = shape[axis], 1
1198 else:
1199 h, w = 1, shape[axis]
1200 else: # multiple axes
1201 axis = sorted(axis.values)
1202 h, w = [shape[i] for i in axis]
1203
1204 # Set necessary depthwise attributes
1205 op.attrs.update(
1206 {
1207 "padding": Padding.VALID,
1208 "stride_h": 1,
1209 "stride_w": 1,
1210 "strides": (1, 1, 1, 1),
1211 "depth_multiplier": 1,
1212 "channel_multiplier": 1,
1213 "dilation_h_factor": 1,
1214 "dilation_w_factor": 1,
1215 "dilation": (1, 1, 1, 1),
1216 }
1217 )
1218 # Change op type
1219 op.type = Op.DepthwiseConv2DBias
1220 # Set IFM/OFM shapes after changing op type
1221 op.set_ifm_ofm_shapes()
1222
1223 weight_scale, bias = 1, None
1224 ofmq, ifmq = op.ofm.quantization, inp.quantization
1225 # Set rounding mode, scaling and zero point based on which reference implementation to match
1226 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1227 if inp.dtype == DataType.uint8:
1228 # This attribute means a different scaling calculation is used in order to match reference
1229 op.low_precision_scaling = True
1230 weight_scale = h * w
1231 # Set zero points to 0 as they will be adjusted for with bias term
1232 foq = ofmq.clone()
1233 foq.zero_point = 0
1234 fiq = ifmq.clone()
1235 fiq.zero_point = 0
1236 op.forced_input_quantization = fiq
1237 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1238 # If the bias term is outside uint8 range, we need an Add op to apply it.
1239 if bias_term < 0 or bias_term > 255:
1240 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1241 # Bias term has higher bitness (i32) than input/output (u8).
1242 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1243 # the bias can only effectively assume values in the range [-255, 255].
1244 intermediate.dtype = DataType.int16
1245 intermediate.quantization.zero_point = 0
1246 add_op = Operation(Op.Add, op.name + "_bias")
1247 add_op.forced_output_quantization = foq
1248 add_op.add_input_tensor(intermediate)
1249 quant = QuantizationParameters()
1250 quant.zero_point = 0
1251 bias_term_tens = create_const_tensor(
1252 op.name + "_bias",
1253 [1, 1, 1, 1],
1254 DataType.int16,
1255 [bias_term],
1256 np.int16,
1257 quantization=quant,
1258 )
1259 add_op.add_input_tensor(bias_term_tens)
1260 add_op.set_output_tensor(op.ofm)
1261 add_op.set_ifm_ofm_shapes()
1262 add_op.activation = op.activation
1263 op.activation = None
1264 op.set_output_tensor(intermediate)
1265 op.set_ifm_ofm_shapes()
1266 # If not, we can just do it with the OFM zero point.
1267 else:
1268 foq.zero_point = bias_term
1269 op.forced_output_quantization = foq
1270 else:
1271 assert inp.dtype == DataType.int8
1272 # Use a depthwise to calculate the sum,
1273 # followed by a multiplication with 1/N to get the MEAN
1274 weight_scale = 1
1275 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1276 intermediate.dtype = DataType.int16
1277 mul_op = Operation(Op.Mul, op.name + "_mul")
1278 mul_op.add_input_tensor(intermediate)
1279 # Create scalar containing 1/N
1280 quant = QuantizationParameters()
1281 quant.zero_point = 0
1282 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1283 # while rounding mode NATURAL would round this to -1.
1284 # This can only occur if N is even, and can be emulated by
1285 # multiplying with a number that is slightly smaller than 1/N.
1286 # It must be so small that other roundings are not affected;
1287 # the calculated value is based on worst case,
1288 # which is sum 256 * N (the maximum sum that can occur with int8)
1289 n = int(h * w)
1290 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1291 quant.scale_f32 = 1 / (n - eps)
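            # Illustrative numbers: a 2x2 mean has n = 4, so eps = 1 / 1280 and
            # scale_f32 = 1 / (4 - 1 / 1280); a raw sum of -6 then scales to roughly -1.5004,
            # which rounds to -2 as the reference does, instead of sitting exactly on -1.5.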
1292 scalar = create_const_tensor(
1293 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1294 )
1295 mul_op.add_input_tensor(scalar)
1296 mul_op.set_output_tensor(op.ofm)
1297 mul_op.set_ifm_ofm_shapes()
1298 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1299 mul_op.activation = op.activation
1300 op.activation = None
1301 op.set_output_tensor(intermediate)
1302 op.set_ifm_ofm_shapes()
1303 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1304 # Here we can just use a simple AvgPool with truncating rounding,
1305 # as we're emulating simple integer division.
1306 op.rounding_mode = NpuRoundingMode.TRUNCATE
1307 op.type = Op.AvgPool
1308 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1309 else:
1310 op.rounding_mode = NpuRoundingMode.NATURAL
1311 weight_scale = 1 / (h * w)
1312 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1313 bias = -ifmq.zero_point * h * w
1314 fiq = ifmq.clone()
1315 fiq.zero_point = 0
1316 op.forced_input_quantization = fiq
1317
1318 # Change dimensions to 4
1319 def extend_dims(dim, in_shape):
1320 if dim < 4:
1321 in_shape = [1] + in_shape
1322 if dim == 2:
1323 in_shape += [1]
1324 return in_shape
1325
1326 if dims < 4 or dims_ofm < 4:
1327 # Fix the ofm dimension when keep_dims is false
1328 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1329 if isinstance(axis, int) and dims_ofm + 1 == dims:
1330 ofm_shape.insert(axis, 1)
1331 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1332 for i in axis:
1333 ofm_shape.insert(i, 1)
1334 shape = extend_dims(dims, shape)
1335 dims_ofm = len(ofm_shape)
1336 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1337 op.set_ifm_ofm_shapes()
1338
1339 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
1340 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
1341 shape = [shape[0], 1, h * w, shape[3]]
1342 op.ifm_shapes[0] = Shape4D(shape)
1343 if h > 256 and op.type == Op.AvgPool:
1344 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1345
1346 # If the AvgPool version is used, we don't need to do anything else
1347 if op.type == Op.AvgPool:
1348 return op
1349
1350 # Make unit weight tensor quantization
1351 weight_quant = ifmq.clone()
1352 weight_quant.min = 0
1353 weight_quant.max = 255
1354 weight_quant.scale_f32 = weight_scale
1355 weight_quant.zero_point = 0
1356
1357 # Set weight shape to [H,W,C,B]
1358 weight_shape = [h, w, shape[3], shape[0]]
1359
1360 # Add unit weight tensor
1361 op.set_input_tensor(
1362 create_const_tensor(
1363 "weights",
1364 weight_shape,
1365 inp.dtype,
1366 np.ones(weight_shape),
1367 value_dtype=np.uint8,
1368 quantization=weight_quant,
1369 ),
1370 1,
1371 )
1372 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
1373
1374 # Add None bias tensor
1375 op.inputs.append(None)
1376 # Add bias tensor
1377 if bias:
1378 bias_shape = [shape[-1]]
1379 op.set_input_tensor(
1380 create_const_tensor(
1381 "bias",
1382 bias_shape,
1383 inp.dtype,
1384 np.ones(bias_shape) * bias,
1385 value_dtype=np.int32,
1386 quantization=None,
1387 ),
1388 2,
1389 )
1390
1391 return op
1392
1393
1394def supported_operator_check(op, arch, nng):
1395 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
1396 return op
1397
1398
1399def tflite_optimise_graph(nng, arch):
1400 # Pre-processing step
1401 pre_process_list = [
1402 supported_operator_check,
1403 set_ifm_ofm_op_shapes,
1404 ]
1405
1406 for idx, sg in enumerate(nng.subgraphs):
1407 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1408 nng,
1409 sg,
1410 arch,
1411 [],
1412 pre_process_list,
1413 rewrite_unsupported=False,
1414 )
1415
1416 # Handle Concat Ops
1417 for idx, sg in enumerate(nng.subgraphs):
1418 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1419 sg.refresh_after_modification()
1420
1421 # Handle Split Ops
1422 for idx, sg in enumerate(nng.subgraphs):
1423 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1424 nng,
1425 sg,
1426 arch,
1427 [],
1428 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1429 rewrite_unsupported=False,
1430 )
1431
1432 for idx, sg in enumerate(nng.subgraphs):
1433 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1434 nng,
1435 sg,
1436 arch,
1437 [rewrite_split_ops],
1438 [],
1439 rewrite_unsupported=False,
1440 )
1441
1442 # Handle subgraph (sg) inputs and outputs
1443 for idx, sg in enumerate(nng.subgraphs):
1444 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1445 nng,
1446 sg,
1447 arch,
1448 [],
1449 [fix_sg_input_output],
1450 rewrite_unsupported=False,
1451 )
1452
1453 # Removal of memory only operators
1454 for sg in nng.subgraphs:
1455 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
1456 sg.refresh_after_modification()
1457
1458 # Rewrite of operators
1459 op_rewrite_list = [
1460 set_tensor_equivalence,
1461 convert_mean_to_depthwise_conv_or_avgpool,
1462 convert_depthwise_to_conv,
1463 convert_conv_to_fc,
1464 convert_softmax,
1465 optimise_strided_conv,
1466 convert_hardswish_to_lut,
1467 rewrite_fully_connected_input,
1468 convert_batched_fc_shape,
1469 fixup_conv2d_backprop,
1470 fixup_relus_with_differing_ifm_ofm_scaling,
1471 reorder_depthwise_weights,
1472 fixup_resizebilinear,
1473 fixup_bias_tensors,
1474 fixup_asymmetric_weights,
1475 convert_mul_max_to_abs_or_lrelu,
1476 convert_lrelu,
1477 convert_tanh_sigmoid_to_lut,
1478 replace_pad_by_hw_pad,
1479 ]
1480
1481 for idx, sg in enumerate(nng.subgraphs):
1482 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1483 nng,
1484 sg,
1485 arch,
1486 [],
1487 op_rewrite_list,
1488 rewrite_unsupported=False,
1489 )
1490
1491 for idx, sg in enumerate(nng.subgraphs):
1492 # remove passthrough tensors and attempt further optimizations
1493 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1494 nng,
1495 sg,
1496 arch,
1497 [remove_passthrough_tensor],
1498 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1499 )
1500
1501 # Removal of SplitSliceRead needs to be done after optimisation has been performed,
1502 # since ifm/ofm_shapes are of importance to this function
1503 for sg in nng.subgraphs:
1504 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1505 sg.refresh_after_modification()
1506
1507 return nng