# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


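# Rewrite the inputs of a concatenation op (Concat/Pack) into one average pool NOP per input, each
# writing its input into the correct offset of the shared output feature map (ConcatSliceWrite)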
def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


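# Rewrite the output of a split-type op into a SplitSliceRead op that reads the relevant slice
# (read offset and read shape) of the producing op's input feature map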
def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


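# Remove a SplitSliceRead op by moving the read to its consumer, or by replacing it with an
# average pool NOP when the consumer cannot perform the read itself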
def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


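# Calculate the (top, left, bottom, right) padding and the skirt for the given padding type, kernel and
# input shape. For example, SAME padding with a 3x3 kernel and stride 1 gives a total padding of 2 in each
# direction, split as padding = (1, 1, 1, 1)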
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


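# Rewrite Conv2DBackpropInput (transpose convolution) to Conv2DBackpropInputSwitchedBias: reorder the
# inputs, use transpose IFM resampling and reset the strides to 1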
def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change resizebilinear to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype == DataType.uint8:
        weight_value_dtype = np.uint8
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        if ofm_dtype == DataType.int8:
            weight_value_dtype = np.int8
        else:
            assert ofm_dtype == DataType.int16
            weight_value_dtype = np.int16

        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm.dtype,
            np.array(weight_values).reshape(weight_shape),
            value_dtype=weight_value_dtype,
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # keep 1x1 kernel and average pool
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape
    return op


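# Convert a batched FullyConnected by folding the batch dimension into height and width
# (e.g. a batch of 8 becomes a 2x4 feature map) so that the whole batch can be processed as one feature map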
def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


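# Adjust the output shapes of StridedSlice ops that use new_axis_mask or shrink_axis_mask, and record
# the resulting 4D split axis in the "split_axis_4D" attribute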
def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


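# Calculate the explicit padding and skirt attributes for all NPU ops that carry a padding attribute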
def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


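# Optimise the horizontal stride-2 of an initial Conv2D by halving the IFM width, doubling the IFM
# depth and repacking the weights to match, so that the convolution can run with stride 1 in x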
def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own primary op to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


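# Convert PReLU to Relu or LeakyRelu when the alpha tensor is constant and uniform, to Mul + Maximum when
# all alpha values are below 1, and otherwise to a Minimum/Mul/Relu/RescaleAdd construct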
def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(np.int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(np.int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.RescaleAdd, op.name + "_add")
        add_op.rescale = (1, 0)  # No scale or shift
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X    For X = -1 or X > 0
    |   \   /     This subgraph can be replaced with either
    |    Mul      an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


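# Convert int8/uint8 HardSwish to a LUT based solution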
def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    alpha = np.float32(op.attrs["alpha"])
    use_mul_max = 0 < alpha < 1
    is_converted_prelu = "alpha_scaling" in op.attrs
    if use_mul_max:
        mul_ifm = ifm
        new_op = Op.Maximum
    else:
        # Need to use a different approach for alpha < 0 or alpha > 1
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
        if alpha < 0 and not is_converted_prelu:
            # For negative alpha that is not from a converted PReLU we need to use
            # int32 Mul below to perform the (negative) alpha scaling
            mul_ifm.dtype = DataType.int32
        min_op.set_output_tensor(mul_ifm)
        min_op.set_ifm_ofm_shapes()
        new_op = Op.RescaleAdd
        op.rescale = (1, 0)  # No scale or shift
        DebugDatabase.add_optimised(op, min_op)

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(mul_ifm)
    # Create const tensor containing alpha as scalar
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    alpha_dtype = mul_ifm.dtype
    if is_converted_prelu:
        # The LeakyRelu was the result from convert_prelu and the scaling is provided
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.type = Op.RescaleMul
        mul_alpha.rescale = [alpha_scale, alpha_shift]
    elif alpha == 0 or np.isinf(1 / alpha):
        # Handling of alpha near or at zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        if alpha_dtype == DataType.int32:
            # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
            scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
        else:
            scalar = 1
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
    )
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if not use_mul_max:
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_id)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)
    elif check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = new_op
    op.name = op.name.replace("LeakyRelu", new_op.name)
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    alpha = op.attrs["alpha"]
    if alpha == 0:
        # When alpha is 0 the operation can be converted to a ReLU
        op.type = Op.Relu
        op.name = op.name.replace("LeakyRelu", op.type.name)
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
        # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


1215def fuse_activation_function_with_prev(op, arch, nng):
1216 # if op is a no-op: attempts to move the activation function to the preceding op
1217 if not op.attrs.get("is_nop", False) or op.activation is None:
1218 return op
1219 ifm, ofm = op.get_ifm_ofm()
1220 if ifm is None or ofm is None:
1221 return op
1222 # finds the input(s) to the operation
1223 prev_op = ifm.ops[0]
1224 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1225 fuse = (
1226 prev_op.run_on_npu
1227 and prev_op.type.npu_block_type != NpuBlockType.Default
1228 and len(ifm.ops) == 1
1229 and len(prev_op.outputs[0].consumers()) == 1
1230 and prev_op.activation is None
1231 )
1232 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1233 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1234 # LUT currently only works correctly for elementwise ops
1235 fuse = False
1236 if not fuse:
1237 return op
1238 # Move the fused activation function + corresponding info to prev_op
1239 prev_op.activation = op.activation
1240 prev_op.forced_output_quantization = op.forced_output_quantization
1241 if op.activation_lut is not None:
1242 prev_op.set_activation_lut(op.activation_lut)
1243 # Bypass op
1244 prev_op.set_output_tensor(ofm)
1245 DebugDatabase.add_optimised(op, prev_op)
1246 return op
1247
1248
1249def _leading_pad_ok(leading_pad, stride, kernel_size):
1250 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1251 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
1252 max_size = kernel_size // 2
1253 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1254
1255
1256def replace_pad_by_hw_pad(op: Operation, arch, nng):
1257 """
1258 Tries to completely remove a PAD operator by using hardware padding.
1259 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1260 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1261 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1262 if both operations can be run on the NPU.
1263 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1264 """
1265 if (
1266 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001267 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001268 and op.run_on_npu
1269 and op.attrs["padding"] == Padding.VALID
1270 ):
1271 pad_op = op.ifm.ops[0]
1272 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1273 return op
1274 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1275 return op
1276 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1277 k = op.kernel
1278 k_w, k_h = k.dilated_wh()
1279
1280 # Check if the PAD operator can be replaced by hardware padding
1281 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1282 # Too much padding, it would require hardware padding to actually insert zeros
1283 return op
1284 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1285 return op
1286
1287 if op.type.is_avgpool_op():
1288 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1289 for pad, k_size in (
1290 (left, k_w),
1291 (right, k_w),
1292 (top, k_h),
1293 (bottom, k_h),
1294 ):
1295 if pad not in (0, k_size // 2):
1296 return op
1297 # Average pool is converted to depthwise, because NPU average pool + same padding
1298 # has a special implementation that is different from PAD followed by average pool with
1299 # valid padding.
1300 k_w, k_h = op.kernel.width, op.kernel.height
1301 ifm = op.ifm
1302 # Remember other inputs
1303 other_inputs = op.inputs[1:]
1304 # Create a weight tensor with all weights set to 1 and scale 1/(kernel width * kernel height)
1305 quantization = QuantizationParameters(0.0, 255.0)
1306 quantization.scale_f32 = 1.0 / (k_w * k_h)
1307 quantization.zero_point = 0
1308 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1309 weights = np.full(shape, 1)
1310
1311 weight_tens = create_const_tensor(
1312 op.name + "_weights",
1313 shape,
1314 op.ifm.dtype,
1315 weights,
1316 np.uint8,
1317 purpose=TensorPurpose.Weights,
1318 quantization=quantization,
1319 )
James Peet7519d502021-07-19 16:47:58 +01001320 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001321 op.type = Op.DepthwiseConv2DBias
1322 op.inputs = []
1323 op.add_input_tensor(ifm)
1324 op.add_input_tensor(weight_tens)
1325 # Add bias tensor, all biases set to 0
1326 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001327 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001328 # Add other inputs
1329 op.inputs.extend(other_inputs)
1330 op.rounding_mode = NpuRoundingMode.NATURAL
1331
1332 # Bypass the PAD operator
1333 op.set_input_tensor(pad_op.ifm, 0)
1334 # Adjust the padding attributes of the convolution operator
1335 op.attrs["padding"] = Padding.EXPLICIT
1336 op.attrs["explicit_padding"] = (top, left, bottom, right)
1337 op.set_ifm_ofm_shapes()
1338 return op
1339
1340
1341def convert_pad(op: Operation, arch, nng):
1342 """
1343 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1344 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1345 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1346 """
1347 if op.type != Op.Pad or not op.run_on_npu:
1348 return op
1349 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1350
1351 ifm = op.ifm
1352 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001353 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001354 ofm = op.ofm
1355 assert ofm is not None
1356 ofm.ops = []
1357 ofm_shape = op.ofm_shapes[0]
1358
1359 # Average pool op that copies IFM to the right place inside the OFM
1360 shp0 = Shape4D(0, 0, 0, 0)
1361 shp_top = shp0.with_height(top)
1362 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1363 avgpool_op.activation = op.activation
1364 quant = ofm.quantization
1365 pad_value = quant.zero_point
1366 # Add operations that fill the borders of the OFM
1367 if top > 0:
1368 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1369 zero_tens = create_const_tensor(
1370 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1371 )
1372 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1373 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1374 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1375 if bottom > 0:
1376 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1377 zero_tens = create_const_tensor(
1378 op.name + "_bottom",
1379 shape.as_list(),
1380 ofm.dtype,
1381 shape.elements() * [pad_value],
1382 np.uint8,
1383 quantization=quant,
1384 )
1385 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1386 create_avg_pool_for_concat(
1387 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1388 )
1389 if left > 0:
1390 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1391 zero_tens = create_const_tensor(
1392 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1393 )
1394 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1395 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1396 if right > 0:
1397 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1398 zero_tens = create_const_tensor(
1399 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1400 )
1401 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1402 create_avg_pool_for_concat(
1403 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1404 )
1405
1406 op.type = Op.ConcatTFLite
1407 return avgpool_op
1408
1409
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001410def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001411 if op.type.needs_bias() and op.bias is None:
1412 # Op has no bias, add bias tensor filled with zeros
1413 nr_biases = op.inputs[1].shape[-1]
1414 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001415 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1416 # DataType. The default is an int32 bias for 8-bit ifms and an int64 bias for int16 ifms.
1417 # For int16, the selected bias DataType affects the scaling used later when encoding the
1418 # scales and biases. The default mode matches the reference, which uses reduced scaling
1419 # for an int64 bias.
1420 # This means that in the cases (in the graph optimiser) where DepthwiseConv2DBias is used
1421 # to emulate average pool, an int32 bias should be selected to get full precision
1422 # int16 scaling.
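# For example (illustrative): an int16 ifm normally gets an int64 zero bias here, but the
# PAD + average pool emulation in replace_pad_by_hw_pad passes dtype=DataType.int32 explicitly.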
1423 if dtype is None:
1424 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1425 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001426 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1427
1428 return op
1429
1430
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001431def fixup_asymmetric_weights(op, arch, nng):
1432 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1433 if op.ifm.dtype == DataType.int8:
1434 if not np.all(op.weights.quantization.zero_point == 0):
1435 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1436 op.weights.quantization.zero_point *= 0
1437
1438 return op
1439
1440
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001441def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1442 if op.type == Op.Mean and op.run_on_npu:
1443 keep_dims = op.attrs.get("keep_dims", False)
1444 inp, axis = op.inputs
1445 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001446 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001447 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001448 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001449
1450 # Height and width axes have different index depending on dimensions
1451 if axis.shape == [] or axis.shape[0] == 1: # single axis
1452 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1453 if dims in (2, 3):
1454 if axis == 0:
1455 h, w = shape[axis], 1
1456 else:
1457 h, w = 1, shape[axis]
1458 else:
1459 if axis == 1:
1460 h, w = shape[axis], 1
1461 else:
1462 h, w = 1, shape[axis]
1463 else: # multiple axes
1464 axis = sorted(axis.values)
1465 h, w = [shape[i] for i in axis]
1466
1467 # Set necessary depthwise attributes
1468 op.attrs.update(
1469 {
1470 "padding": Padding.VALID,
1471 "stride_h": 1,
1472 "stride_w": 1,
1473 "strides": (1, 1, 1, 1),
1474 "depth_multiplier": 1,
1475 "channel_multiplier": 1,
1476 "dilation_h_factor": 1,
1477 "dilation_w_factor": 1,
1478 "dilation": (1, 1, 1, 1),
1479 }
1480 )
1481 # Change op type
1482 op.type = Op.DepthwiseConv2DBias
1483 # Set IFM/OFM shapes after changing op type
1484 op.set_ifm_ofm_shapes()
1485
1486 weight_scale, bias = 1, None
1487 ofmq, ifmq = op.ofm.quantization, inp.quantization
1488 # Set rounding mode, scaling and zero point based on which reference implementation to match
1489 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1490 if inp.dtype == DataType.uint8:
1491 # This attribute means a different scaling calculation is used in order to match reference
1492 op.low_precision_scaling = True
1493 weight_scale = h * w
1494 # Set zero points to 0 as they will be adjusted for with bias term
1495 foq = ofmq.clone()
1496 foq.zero_point = 0
1497 fiq = ifmq.clone()
1498 fiq.zero_point = 0
1499 op.forced_input_quantization = fiq
Johan Alfvén17009392022-08-30 09:14:56 +02001500 bias_term = ofmq.zero_point - round_up_to_int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
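# Illustrative values: with equal scales and ifm/ofm zero points of 0 and 128, bias_term
# becomes 128, which fits in uint8 and is simply folded into the forced OFM zero point below.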
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001501 # If the bias term is outside uint8 range, we need an Add op to apply it.
1502 if bias_term < 0 or bias_term > 255:
1503 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1504 # Bias term has higher bitness (i32) than input/output (u8).
1505 # 16 bits is enough since the bias is added/subtracted from a u8 value;
1506 # the bias can only effectively assume values in the range [-255, 255].
1507 intermediate.dtype = DataType.int16
1508 intermediate.quantization.zero_point = 0
1509 add_op = Operation(Op.Add, op.name + "_bias")
1510 add_op.forced_output_quantization = foq
1511 add_op.add_input_tensor(intermediate)
1512 quant = QuantizationParameters()
1513 quant.zero_point = 0
1514 bias_term_tens = create_const_tensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001515 op.name + "_bias",
1516 [1, 1, 1, 1],
1517 DataType.int16,
1518 [bias_term],
1519 np.int16,
1520 quantization=quant,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001521 )
1522 add_op.add_input_tensor(bias_term_tens)
1523 add_op.set_output_tensor(op.ofm)
1524 add_op.set_ifm_ofm_shapes()
1525 add_op.activation = op.activation
1526 op.activation = None
1527 op.set_output_tensor(intermediate)
1528 op.set_ifm_ofm_shapes()
1529 # If not, we can just do it with the OFM zero point.
1530 else:
1531 foq.zero_point = bias_term
1532 op.forced_output_quantization = foq
1533 else:
1534 assert inp.dtype == DataType.int8
1535 # Use a depthwise to calculate the sum,
1536 # followed by a multiplication with 1/N to get the MEAN
1537 weight_scale = 1
1538 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
Johan Alfvén05916632022-09-06 20:33:22 +02001539 intermediate.dtype = DataType.int32
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001540 mul_op = Operation(Op.Mul, op.name + "_mul")
1541 mul_op.add_input_tensor(intermediate)
Johan Alfvén05916632022-09-06 20:33:22 +02001542 mul_op.set_output_tensor(op.ofm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001543 # Create scalar containing 1/N
1544 quant = QuantizationParameters()
1545 quant.zero_point = 0
1546 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1547 # while rounding mode NATURAL would round this to -1.
1548 # This can only occur if N is even, and can be emulated by
1549 # dividing by a number that is slightly smaller than N (i.e. a scale slightly larger than 1/N).
1550 # It must be so small that other roundings are not affected;
1551 # the calculated value is based on worst case,
1552 # which is sum 256 * N (the maximum sum that can occur with int8)
1553 n = int(h * w)
1554 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1555 quant.scale_f32 = 1 / (n - eps)
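# Illustrative values: n=4 gives eps=1/1280 and a scale of ~0.250049 (slightly above 1/4),
# so a sum of -6 maps to ~-1.5003 and NATURAL rounding yields -2, matching the reference,
# while e.g. a sum of -5 still maps to ~-1.25 and rounds to -1 as before.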
Johan Alfvén05916632022-09-06 20:33:22 +02001556
1557 # For int8/int16 we could use IFM/OFM scaling to do the division
1558 # intermediate * 1 -> scale > round and shift.
1559 #
1560 # For int32 scaling is not supported so instead multiply with the scale
1561 # intermediate * scale -> round and shift.
1562 #
1563 # Calculate the scale and shift value. const Tensor must be created
1564 # with correct quantization since the scale and shift is calculated later
1565 # in the command stream generator.
1566 mul_scale, _ = scaling.elementwise_mul_scale(
1567 mul_op.ifm.quantization.scale_f32, quant.scale_f32, mul_op.ofm.quantization.scale_f32
1568 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001569 scalar = create_const_tensor(
Johan Alfvén05916632022-09-06 20:33:22 +02001570 op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [mul_scale], np.int32, quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001571 )
1572 mul_op.add_input_tensor(scalar)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001573 mul_op.set_ifm_ofm_shapes()
1574 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1575 mul_op.activation = op.activation
1576 op.activation = None
1577 op.set_output_tensor(intermediate)
1578 op.set_ifm_ofm_shapes()
1579 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1580 # Here we can just use a simple AvgPool with truncating rounding,
1581 # as we're emulating simple integer division.
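# Illustrative example: averaging the two values 3 and 4 gives the sum 7, and with
# TRUNCATE the result is 7 // 2 = 3, i.e. plain integer division.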
1582 op.rounding_mode = NpuRoundingMode.TRUNCATE
1583 op.type = Op.AvgPool
1584 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1585 else:
1586 op.rounding_mode = NpuRoundingMode.NATURAL
1587 weight_scale = 1 / (h * w)
1588 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1589 bias = -ifmq.zero_point * h * w
1590 fiq = ifmq.clone()
1591 fiq.zero_point = 0
1592 op.forced_input_quantization = fiq
1593
1594 # Change dimensions to 4
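# For example (illustrative): a 2-D shape [H, W] is extended to [1, H, W, 1] and a
# 3-D shape [H, W, C] to [1, H, W, C].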
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001595 def extend_dims(dim, in_shape):
1596 if dim < 4:
1597 in_shape = [1] + in_shape
1598 if dim == 2:
1599 in_shape += [1]
1600 return in_shape
1601
1602 if dims < 4 or dims_ofm < 4:
1603 # Fix the ofm dimension when keep_dims is false
1604 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1605 if isinstance(axis, int) and dims_ofm + 1 == dims:
1606 ofm_shape.insert(axis, 1)
1607 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1608 for i in axis:
1609 ofm_shape.insert(i, 1)
1610 shape = extend_dims(dims, shape)
1611 dims_ofm = len(ofm_shape)
1612 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1613 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001614
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001615 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
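# Illustrative example: a reduction over h=80, w=8 (h > 64) is instead performed as a
# 1x640 reduction, keeping the kernel height within the hardware limit.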
1616 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001617 shape = [shape[0], 1, h * w, shape[3]]
1618 op.ifm_shapes[0] = Shape4D(shape)
1619 if h > 256 and op.type == Op.AvgPool:
1620 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1621
1622 # If the AvgPool version is used, we don't need to do anything else
1623 if op.type == Op.AvgPool:
1624 return op
1625
1626 # Make unit weight tensor quantization
1627 weight_quant = ifmq.clone()
1628 weight_quant.min = 0
1629 weight_quant.max = 255
1630 weight_quant.scale_f32 = weight_scale
1631 weight_quant.zero_point = 0
1632
1633 # Set weight shape to [H,W,C,B]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001634 weight_shape = [h, w, shape[3], shape[0]]
1635
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001636 # Add unit weight tensor
1637 op.set_input_tensor(
1638 create_const_tensor(
1639 "weights",
1640 weight_shape,
1641 inp.dtype,
1642 np.ones(weight_shape),
1643 value_dtype=np.uint8,
1644 quantization=weight_quant,
1645 ),
1646 1,
1647 )
James Peet7519d502021-07-19 16:47:58 +01001648 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001649
1650 # Add None bias tensor
1651 op.inputs.append(None)
1652 # Add bias tensor
1653 if bias:
1654 bias_shape = [shape[-1]]
1655 op.set_input_tensor(
1656 create_const_tensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001657 "bias",
1658 bias_shape,
1659 inp.dtype,
1660 np.ones(bias_shape) * bias,
1661 value_dtype=np.int32,
1662 quantization=None,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001663 ),
1664 2,
1665 )
1666
1667 return op
1668
1669
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001670def optimise_quantize(op: Operation, arch, nng):
1671
1672 if op.type == Op.Quantize and op.run_on_npu:
1673
1674 ifm, ofm = op.get_ifm_ofm()
1675 input_values = ifm.values
1676
1677 # Guard clause - input not const or no values to quantize
1678 if ifm.ops[0].type != Op.Const or input_values is None:
1679 return op
1680
1681 # Scalar value stored as a 0-d numpy array, convert to an indexable array
1682 if input_values.ndim == 0:
1683 input_values = np.array([input_values])
1684
Fredrik Svedberg11563172022-07-06 14:54:12 +02001685 # Requantize int8 to int8 or int16 to int16
1686 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001687
1688 # scale needs to use double precision to match TFLite reference kernel
1689 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1690 effective_multiplier, effective_shift = quantise_scale(effective_scale)
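# Illustrative values: an ifm scale of 0.02 and an ofm scale of 0.01 give an effective
# scale of 2.0, encoded as a fixed-point multiplier/shift pair, so an input of 10 (after
# zero-point removal) requantises to ~20 before the output zero point and clamping are applied.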
1691
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001692 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001693 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001694 input_val = val - ifm.quantization.zero_point
1695
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001696 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1697 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001698
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001699 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1700 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001701
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001702 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1703 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001704
1705 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001706 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001707
1708 quantized_vals = []
1709 for val in input_values:
1710
1711 # Derive quantized value
1712 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001713 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1714 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001715
1716 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001717 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1718
1719 # Unsupported data type
1720 else:
1721 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001722
1723 # Make quantize op const and disconnect from parent node
1724
1725 # Remove reference of the current quant op from the parent tensor's consumer list
1726 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1727
1728 # Clear any references to parent node
1729 op.inputs = []
1730
1731 # Convert this quantize op to const
1732 op.type = Op.Const
1733
1734 return op
1735
1736
Ayaan Masood4965fae2022-06-29 11:30:57 +01001737def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1738 """Static optimisation for SHAPE operator output value known at compile time"""
1739
1740 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
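# Illustrative example: for an ifm of shape [1, 224, 224, 3] the SHAPE output becomes a
# constant tensor holding [1, 224, 224, 3], and the op itself is turned into Op.Const.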
1741
1742 if op.type == Op.Shape and op.run_on_npu:
1743
1744 ifm, ofm = op.get_ifm_ofm()
1745
1746 if len(ifm.shape) != ofm.shape[0]:
1747 return op
1748
1749 # Remove reference of the current shape op from the parent tensor's consumer list
1750 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1751
1752 # Clear any references to parent node
1753 op.inputs = []
1754
1755 # Convert this SHAPE op to const
1756 op.type = Op.Const
1757
1758 # Set the SHAPE output tensor values to the compile-time shape of the ifm
1759 ofm.values = np.array(ifm.shape)
1760
1761 return op
1762
1763
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001764def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001765 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001766 return op
1767
1768
1769def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001770 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001771 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1772
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001773 for idx, sg in enumerate(nng.subgraphs):
1774 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001775 nng,
1776 sg,
1777 arch,
1778 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001779 optimisation_list,
1780 rewrite_unsupported=False,
1781 )
1782
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001783 # Pre-processing step
1784 pre_process_list = [
1785 supported_operator_check,
1786 set_ifm_ofm_op_shapes,
1787 ]
1788
Ayaan Masood4965fae2022-06-29 11:30:57 +01001789 for idx, sg in enumerate(nng.subgraphs):
1790 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1791 nng,
1792 sg,
1793 arch,
1794 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001795 pre_process_list,
1796 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001797 )
1798
1799 # Handle Concat Ops
1800 for idx, sg in enumerate(nng.subgraphs):
1801 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1802 sg.refresh_after_modification()
1803
1804 # Handle Split Ops
1805 for idx, sg in enumerate(nng.subgraphs):
1806 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1807 nng,
1808 sg,
1809 arch,
1810 [],
1811 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1812 rewrite_unsupported=False,
1813 )
1814
1815 for idx, sg in enumerate(nng.subgraphs):
1816 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001817 nng,
1818 sg,
1819 arch,
1820 [rewrite_split_ops],
1821 [],
1822 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001823 )
1824
1825 # Handle sg input output
1826 for idx, sg in enumerate(nng.subgraphs):
1827 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001828 nng,
1829 sg,
1830 arch,
1831 [],
1832 [fix_sg_input_output],
1833 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001834 )
1835
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001836 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001837 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001838 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001839 sg.refresh_after_modification()
1840
1841 # Rewrite of operators
1842 op_rewrite_list = [
1843 set_tensor_equivalence,
1844 convert_mean_to_depthwise_conv_or_avgpool,
1845 convert_depthwise_to_conv,
1846 convert_conv_to_fc,
1847 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001848 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001849 convert_mul_max_to_abs_or_lrelu,
1850 convert_lrelu,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001851 optimise_strided_conv,
1852 convert_hardswish_to_lut,
1853 rewrite_fully_connected_input,
1854 convert_batched_fc_shape,
1855 fixup_conv2d_backprop,
1856 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001857 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001858 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001859 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001860 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001861 convert_tanh_sigmoid_to_lut,
1862 replace_pad_by_hw_pad,
1863 ]
1864
1865 for idx, sg in enumerate(nng.subgraphs):
1866 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001867 nng,
1868 sg,
1869 arch,
1870 [],
1871 op_rewrite_list,
1872 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001873 )
1874
1875 for idx, sg in enumerate(nng.subgraphs):
1876 # remove passthrough tensors and attempt further optimizations
1877 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1878 nng,
1879 sg,
1880 arch,
1881 [remove_passthrough_tensor],
1882 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1883 )
1884
1885 # Removal of SplitSliceRead; this needs to be done after the optimisations above have been performed,
1886 # since the ifm/ofm shapes are of importance to this function
1887 for sg in nng.subgraphs:
1888 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1889 sg.refresh_after_modification()
1890
1891 return nng