# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import create_avg_pool_for_concat
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .lstm import Lstm
from .lut import convert_to_lut
from .lut import create_lut_8bit_op
from .lut import create_lut_int16_op
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .numeric_util import round_down_log2
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_cast_op
from .operation_util import create_depthwise_maxpool
from .operation_util import create_memcpy
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
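            # e.g. packing tensors of shape [2, 3] on axis=-1: axis = 2 + 1 + (-1) = 2,
            # i.e. the last axis of the rank-3 output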

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]
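            # e.g. offset_start = [0, 0, 0, 4], offset_end = [1, 8, 8, 8] -> read_shape = [1, 8, 8, 4]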

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
        # or if an avgpool needs to be inserted
        if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
            consumer is not None and consumer.run_on_npu and consumer.type not in memory_only_ops
            for consumer in op.ofm.consumer_list
        ):
            # SplitSliceRead can be performed by tensor consumer(s)
            for cons_op in list(op.ofm.consumer_list):
                move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
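        # any odd total padding goes to the right/bottom, e.g. xpad = 3 -> left_pad = 1, right_pad = 2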
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
        DebugDatabase.add_optimised(op, op)

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    name = op.inputs[1].name + "_add"
    dtype = op.inputs[0].dtype
    shape = op.ofm_shapes[0].as_list()
    values = np.zeros(shape, dtype.as_numpy_type())
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = 1.0
    quantization.zero_point = 0
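    # with scale 1.0 and zero point 0 the zero tensor contributes nothing, so the
    # elementwise Add simply requantizes the remaining input to the output scaling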
    op.inputs[1] = op.inputs[0]
    op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 0)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change resizebilinear to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype.type == BaseType.UnsignedInt:
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    #   0---1---2
    #   | A | B |
    #   1---*---+
    #   | C | D |
    #   2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
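    # e.g. upscale_factor = 2: centre_coeff = 1 * 2 + 1 = 3, i.e. the weight for value D above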

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm_dtype,
            np.array(weight_values).reshape(weight_shape),
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))
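    # e.g. upscale_factor = 8 -> n = 3: two x2 upscales in the loop below plus the final x2 stage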

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

        if scaled_op.original_type == Op.ResizeBilinear:
            if scaled_op.attrs["align_corners"]:
                # no padding
                scaled_op.attrs["padding"] = Padding.VALID
            else:
                # padding to the right and bottom (limits average pool to 8x8 kernel)
                scaled_op.attrs["padding"] = Padding.EXPLICIT
                scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

            # kernel size dependent on the upscaling factor
            scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
        else:  # Op.ResizeNearestNeighbor
            if scaled_op.attrs["align_corners"]:
                # use depthwise conv to select the correct value
                scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
            else:
                # Keep 1x1 kernel and average pool, this applies both when
                # half-pixel-centers is True and False. Calculations are the
                # same in the reference.
                pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, scaled_op)

    return op


def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
    """
    Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.

    Example:
    arr = [4,    [00000100,
           6,  =  00000110,  # <-- This is the largest value, so we're expecting argmax(arr) = 1
           5]     00000101]

    Use 16-bit precision and shift all values 7 bits to the left:
    Shifted_arr = [0000001000000000,
                   0000001100000000,
                   0000001010000000]

    Add "c - index of channel" to each channel:
    Shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
                                    0000001100000001, (+1)
                                    0000001010000000] (+0)

    The index is reversed since ArgMax selects the lowest index if the maximum value is found at two indices. The
    index acts as a tie-breaker between channels with equal values, and since we want the smallest channel index to
    be chosen we reverse the index before the maxpool and then subtract the index from the number of channels after
    the maxpool to get the correct index.

    Find the maximum value in the array:
    val = max(shifted_arr_plus_reverse_idx) = 0000001100000001

    Subtract the value from the number of channels:
    shifted_arr_plus_idx = (c-1) - val = 2 - 1 = 1

    Extract the 7 lowest bits using a LUT to cut off the 9 most significant bits:
    idx = LUT(val) = 0000000000000001 = 1
    """

    if op.type == Op.ArgMax:
        ifm, ofm = op.inputs[0], op.outputs[0]
        identity_quant = QuantizationParameters()
        identity_quant.zero_point = 0
        identity_quant.scale_f32 = 1.0
        # Add last dimension to ofm shape
        ofm.shape += [1]
        ofm.ops = []

        # Create 1x1 Depthwise convolution with 2**7 weights for each channel to convert precision to 16 bit and shift
        # all values 7 bits to the left
        # Set necessary depthwise attributes
        dw_op_attrs = {
            "padding": Padding.VALID,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
            "explicit_padding": None,
        }
        orig_name = op.name
        op.name = f"{orig_name}_depthwise_conv_SHL_7"
        op.type = Op.DepthwiseConv2DBias
        op.attrs.update(dw_op_attrs)
        n, h, w, c = full_shape(4, ifm.shape, 1)
        shape = [1, 1, 1, c]
        kernel = np.dstack([2**7] * c)
        op.inputs = []
        op.add_input_tensor(ifm)
        op.add_input_tensor(
            create_const_tensor(
                "weights",
                shape,
                DataType.uint8,
                np.array(kernel).reshape(shape),
                quantization=identity_quant,
            ),
        )
        # Let the bias for each channel be the "reverse" index of the channel it is in, i.e. c - channel_idx
        reverse_idxs = list(reversed(range(c)))
        bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
        op.add_input_tensor(bias_tensor)

        intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
        intermediate_tens.quantization = ifm.quantization
        op.set_output_tensor(intermediate_tens)
        op.set_ifm_ofm_shapes()
        orig_ifm_shape = op.ifm_shapes[0]
        DebugDatabase.add_optimised(op, op)

        # To extract 7 least significant bits and swap reverse index back to real index using a LUT activation, we set
        # the base value to c-1 and slope to -128. The 16-bit LUT uses a table of 32-bit values where the top 16 bits
        # represent the slope and bottom 16 bits the base which are used to interpolate the activation value.
        slope = (-128 & 0xFFFF) << 16  # Top 16 bits of 32 bit LUT table value
        base = c - 1  # Bottom 16 bits of the LUT table value
        lut_tensor = create_const_tensor(
            "maxpool_LUT_extract_7_LSB",
            [1, 1, 1, 512],
            DataType.uint32,
            [slope + base] * 512,
            TensorPurpose.LUT,
        )

        # Split large feature maps into smaller chunks since the Depthwise Maxpool height dimension can overflow due to
        # flattening the ifm to (H*W)xCx1
        max_height = 2**16 // orig_ifm_shape.width
        num_full_height_ops = orig_ifm_shape.height // max_height
        last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
        op_heights = [max_height] * num_full_height_ops
        if last_op_height > 0:
            op_heights.append(last_op_height)
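        # e.g. a 1x100x1024xC ifm: max_height = 65536 // 1024 = 64, so the maxpool is split into
        # chunks of height [64, 36]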

        # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
        # maximum allowed height, but that's handled by reading and writing the data in chunks
        maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
        maxpool_ofm.quantization = identity_quant

        for op_idx, op_height in enumerate(op_heights):
            maxpool_op = create_depthwise_maxpool(
                f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
            )
            maxpool_op.outputs = [maxpool_ofm]
            maxpool_ofm.ops.append(maxpool_op)
            maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
            maxpool_op.set_activation_lut(lut_tensor)

            # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
            maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
            maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
            maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
            maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
            DebugDatabase.add_optimised(op, maxpool_op)

        # Set final shape
        maxpool_ofm.set_all_shapes([1, h, w, 1])

        # Convert 16bit to 32bit or 64bit
        if ofm.dtype == DataType.int64:
            # If OFM dtype is int64 the result is converted by two cast ops (16bit to 32bit)
            #
            #   A      -> B          -> C           -> D (OFM)
            #   |0001|    |00010000|    |0001|0000|    |00010000|00000000|
            #    i16       i32            i16  i16       i32       i32
            #                                           <-------i64------->
            #
            #   Memcpy is used to copy the content from B to C and from D to OFM
            #   Memcpy will be turned into a nop or a DMA transfer if memory regions differ.
            intermediate_32bit = Tensor([1, h, w, 1], DataType.int32, f"{orig_name}_32bit")
        else:
            intermediate_32bit = ofm

        op_cast = create_cast_op(f"{orig_name}_cast_to_32bit_1", maxpool_ofm, intermediate_32bit)
        DebugDatabase.add_optimised(op, op_cast)

        if ofm.dtype == DataType.int64:
            # Create int16 tensor with double shape to cover the intermediate_32bit result from the first cast
            intermediate_16bit_2x_size = Tensor([1, h, w, 2], DataType.int16, f"{orig_name}_16bit_2x_size")
            memcpy_op = create_memcpy(f"{orig_name}_memcpy_1", intermediate_32bit, intermediate_16bit_2x_size)
            DebugDatabase.add_optimised(op, memcpy_op)

            # Create int32 tensor with double ofm shape to be able to store a "int64" result
            intermediate_32bit_2x_size = Tensor([1, h, w, 2], DataType.int32, f"{orig_name}_32bit_2x_size")

            op_cast = create_cast_op(
                f"{orig_name}_cast_to_32bit_2", intermediate_16bit_2x_size, intermediate_32bit_2x_size
            )
            DebugDatabase.add_optimised(op, op_cast)

            memcpy_op = create_memcpy(f"{orig_name}_memcpy_2", intermediate_32bit_2x_size, ofm)
            DebugDatabase.add_optimised(op, memcpy_op)

    return op


def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)
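        # e.g. x2 upscale (scale = 0.5) with half_pixel_centers: index 1 -> scaled_value = 0.25,
        # index 2 -> scaled_value = 0.75; lower_bound is 0 in both cases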

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
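                # scale by 16 so the fractional interpolation weights become integers; this
                # is compensated for by the weight scale of 1.0 / 16 set further down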
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size
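                # e.g. int8 (elem_size = 1), i = 0, j = 1: Tile 0 starts c bytes in, i.e. at
                # output column 1; i = 1, j = 0 starts a full output row in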

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, dw_conv)

                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):
    # If the operation already has a read shape do not modify
    # the ifm shape, since that will already be correct
    if op.type == Op.FullyConnected and not op.read_shapes[0]:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
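            # e.g. batch 8 -> a 2x4 spatial layout; batch sizes not in the table become 1xN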
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, act_op)

def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def fixup_strided_conv(op: Operation, arch, nng):
    """Optimize or fixup strided Conv2DBias
    Optimization:
        Reduce, when possible, the Conv2DBias stride from 2 to 1 by re-shaping
        both IFM and filter.

    Fixup:
        Introduce software support for Conv2DBias with stride_width = 4 by
        reducing it to 1 when possible by re-shaping both IFM and filter.
    """
    if op.type != Op.Conv2DBias:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]
    if (
        (stride_x == 2 or stride_x == 4)
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, stride_x, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // stride_x, 1, (k_w + 1) // stride_x)
        padding_type = op.attrs.get("padding", None)

        # If padding is enabled, check if current padding matches optimised padding
        if not padding_type or (padding_type != Padding.VALID and curr_padding_x != optimised_padding_x):
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D(
            [ifm_shape.batch, ifm_shape.height, ifm_shape.width // stride_x, ifm_shape.depth * stride_x]
        )

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array

        # Change weight shape based on stride_x
        weight_shape[1] //= stride_x
        weight_shape[2] *= stride_x

        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

        op.ifm.force_linear_format = True
    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_lstm(op, arch, nng):
    if op.type == Op.UnidirectionalSequenceLstm:
        lstm = Lstm(op)
        op = lstm.get_graph()
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()
                    DebugDatabase.add_optimised(op, mul_identity)

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.Add, op.name + "_add")
        add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X    For X = -1 or X > 0
    |   \   /     This subgraph can be replaced with either
    |    Mul      an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

        return op
1291
1292
1293def convert_hardswish_to_lut(op, arch, nng):
1294 if op.type == Op.HardSwish:
1295 ifm, ofm = op.get_ifm_ofm()
1296 # Generate the LUT
1297 ifm_scale = np.double(ifm.quantization.scale_f32)
1298 ofm_scale = np.double(ofm.quantization.scale_f32)
1299 zp_in = ifm.quantization.zero_point
1300 zp_out = ofm.quantization.zero_point
1301 ifm_scale_hires = (1 / 128) * ifm_scale
1302 relu_multiplier = np.double(3 / 32768)
1303 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
1304 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
1305 # Use 16bit scale
1306 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
1307 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
1308
1309 values = []
1310 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1311 quantized_min = min(ix)
1312 quantized_max = max(ix)
1313 for x in ix:
1314 input_value = x - zp_in
1315 input_value_hires = input_value * 128
1316 # Compute the input value on essentially the output scale, not shifted yet
1317 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
1318 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
1319 relu_value = np.int16(input_value_hires)
1320 if relu_shift < 31:
1321 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
1322
1323 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
1324
1325 if relu_shift < 31:
1326 relu_value = fp_math.shift_left16(relu_value, 1)
1327
1328 if relu_shift > 31:
1329 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
1330
1331 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
1332 # Now convert that to a 16bit fixedpoint value in [0, 1]
1333 relu_value = (relu_value + (1 << 15)) >> 1
1334 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
1335 shift = 31 - out_shift
1336 shift = -shift if shift < 0 else 0
1337 # Finally apply the output shift
1338 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
1339 lut_result = min(quantized_max, max(quantized_min, lut_result))
1340 values.append(lut_result)
1341 return convert_to_lut(op, values, "hardswish")
1342 return op
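
# Background note: hardswish(x) = x * relu6(x + 3) / 6. The loop in
# convert_hardswish_to_lut evaluates this in fixed point: relu_value
# approximates relu6(x + 3) / 6 as a Q15 value in [0, 1], which is multiplied
# with the input (pre-scaled to the output quantisation) and then shifted onto
# the output scale. E.g. for x_real = 3.0 the relu term saturates to 1.0, so
# the LUT entry is simply 3.0 requantised to the output scale.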


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    alpha = np.float32(op.attrs["alpha"])
    use_mul_max = 0 < alpha < 1
    is_converted_prelu = "alpha_scaling" in op.attrs
    if use_mul_max:
        mul_ifm = ifm
        new_op = Op.Maximum
    else:
        # Need to use a different approach for alpha < 0 or alpha > 1
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
        if alpha < 0 and not is_converted_prelu:
            # For negative alpha that is not from a converted PReLU we need to use
            # int32 Mul below to perform the (negative) alpha scaling
            mul_ifm.dtype = DataType.int32
        min_op.set_output_tensor(mul_ifm)
        min_op.set_ifm_ofm_shapes()
        new_op = Op.Add
        op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        DebugDatabase.add_optimised(op, min_op)

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(mul_ifm)
    # Create const tensor containing alpha as scalar
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    alpha_dtype = mul_ifm.dtype
    if is_converted_prelu:
        # The LeakyRelu was the result from convert_prelu and the scaling is provided
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
    elif alpha == 0 or np.isinf(1 / alpha):
        # Handling of alpha near or at zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        if alpha_dtype == DataType.int32:
            # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
            scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
        else:
            scalar = 1
    alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if not use_mul_max:
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_id)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)
    elif check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max (or Add), with the results of the multiplication(s) as inputs
    op.type = new_op
    op.name = op.name.replace("LeakyRelu", new_op.name)
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op
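
# Structural note: for 0 < alpha < 1 convert_lrelu_to_mul_max produces
#
#     ifm ---> Mul(alpha) ----> fm_alpha --\
#       \                                   Maximum --> ofm
#        `--> [Mul(identity)] -> fm_id ----/
#
# where the identity Mul is only created when ifm and ofm scaling differ. For
# other alpha values the negative part is selected with a Minimum, scaled, and
# recombined with a Relu of the input through an Add.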


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)
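
# Worked example (assuming the common int8 sigmoid output quantisation of
# scale 1 / 256 and zero point -128): with fn = clamp_sigmoid, ifm_scale = 0.1
# and zp_in = 0, the entry for x = 0 is round_away_zero(-128 + 0.5 / (1 / 256))
# = round_away_zero(-128 + 128) = 0, i.e. sigmoid(0) = 0.5 maps to mid-range.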


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")
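
# Note: the LUT above implements both LeakyReLU branches directly: entries
# below the input zero point are scaled by alpha (reusing the pre-computed
# "alpha_scaling" when the op came from a converted PReLU or Mul/Max pattern),
# while entries at or above it only get the identity rescaling from ifm to ofm
# quantisation.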


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    alpha = op.attrs["alpha"]
    if alpha == 0:
        # When alpha is 0 the operation can be converted to a ReLU
        op.type = Op.Relu
        op.name = op.name.replace("LeakyRelu", op.type.name)
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
        # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
        return op
    return convert_lrelu_to_mul_max(op, arch)
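
# Dispatch summary: alpha == 0 becomes a plain Relu; int8/uint8 uses a LUT;
# int16 with equal ifm/ofm scaling and positive alpha keeps the LeakyRelu op
# unmodified; everything else falls back to the Mul/Max (or Min/Add)
# decomposition in convert_lrelu_to_mul_max.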


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(prev_op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
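
# Worked example: for a 7-wide kernel at stride 2, max_size = 7 // 2 = 3, so a
# leading pad of 0, 2 or 3 is accepted while a leading pad of 1 is rejected,
# because hardware padding would then sample the wrong IFM columns.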


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng, DataType.int32)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

    return op
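
# Worked example: a PAD of (top, left, bottom, right) = (1, 1, 1, 1) feeding a
# 3x3 convolution with VALID padding is folded away here; the convolution
# keeps its OFM but switches to Padding.EXPLICIT with (1, 1, 1, 1), which the
# hardware applies at no extra cost.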


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as a fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op
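
# Layout note: after the rewrite the OFM is assembled from up to five copies,
#
#     +---------- top ----------+
#     | left |    IFM   | right |
#     +-------- bottom ---------+
#
# where each border region is a zero-point-filled constant written by its own
# average pool and the centre copy carries the original activation function.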


def fixup_bias_tensors(op, arch, nng, dtype=None):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
        # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
        # For int16 the selected bias DataType will have an impact on the scaling
        # used when encoding the scales and biases later. The default mode will match the
        # reference with reduced scaling for int64 bias.
        # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
        # is used to emulate average pool, int32 bias should be selected for full precision
        # int16 scaling.
        if dtype is None:
            dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op


def detect_asymmetric_weights(op):
    # Check all ops (cpu and npu)
    if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
        if op.ifm.dtype in (DataType.int8, DataType.int16):
            if not np.all(op.weights.quantization.zero_point == 0):
                print(f"Warning: Op {op.type} '{op.name}' has asymmetric weights.", end=" ")
                return True
    return False


def fixup_asymmetric_weights(op, arch, nng):
    if detect_asymmetric_weights(op):
        if op.run_on_npu:
            print("Zero points have been adjusted.")
            op.weights.quantization.zero_point *= 0
    return op


def check_asymmetric_weights(op, arch, nng):
    # This function can modify the run_on_npu flag, which causes an operator to be placed on the CPU. The flag is
    # usually only set by the supported operator checks. Therefore, this function should be run immediately after
    # those checks to avoid the possibility of other graph optimiser functions modifying an operator that is later
    # run on the CPU.
    if detect_asymmetric_weights(op):
        if op.run_on_npu:
            print("To run the operator on Ethos-U use the option --force-symmetric-int-weights")
            op.run_on_npu = False
    return op


def fixup_or_check_asymmetric_weights(force_symmetric_int_weights):
    if force_symmetric_int_weights:
        return fixup_asymmetric_weights
    else:
        return check_asymmetric_weights


def convert_mean_to_depthwise_conv(op, arch, nng):
    if op.type == Op.Mean and op.run_on_npu:
        inp, axis = op.inputs
        shape = inp.shape
        ofm_shape = op.ofm.shape
        dims = len(shape)
        dims_ofm = len(ofm_shape)

        # Height and width axes have different index depending on dimensions
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
            if dims in (2, 3):
                # If dims is 2 or 3, axis 0 refers to h-dimension
                h, w = (shape[axis], 1) if axis == 0 else (1, shape[axis])
            else:
                # If dims is 4, axis 1 refers to h-dimension
                h, w = (shape[axis], 1) if axis == 1 else (1, shape[axis])
        else:  # multiple axes
            axis = sorted(axis.values)
            h, w = [shape[i] for i in axis]

        # Set necessary depthwise attributes
        op.attrs.update(
            {
                "padding": Padding.VALID,
                "stride_h": 1,
                "stride_w": 1,
                "strides": (1, 1, 1, 1),
                "depth_multiplier": 1,
                "channel_multiplier": 1,
                "dilation_h_factor": 1,
                "dilation_w_factor": 1,
                "dilation": (1, 1, 1, 1),
            }
        )
        # Change op type
        op.type = Op.DepthwiseConv2DBias
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()

        # Change dimensions to 4
        def extend_dims(dim, in_shape):
            if dim < 4:
                in_shape = [1] + in_shape
                if dim == 2:
                    in_shape += [1]
            return in_shape

        if dims < 4 or dims_ofm < 4:
            # Fix the ofm dimension when keep_dims is false
            # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
            if isinstance(axis, int) and dims_ofm + 1 == dims:
                ofm_shape.insert(axis, 1)
            elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
                for i in axis:
                    ofm_shape.insert(i, 1)
            shape = extend_dims(dims, shape)
            dims_ofm = len(ofm_shape)
            ofm_shape = extend_dims(dims_ofm, ofm_shape)
            op.set_ifm_ofm_shapes()

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        if h > 64:
            # This can only happen and be done for multiple axes, and
            # h * w <= 4096 for DepthwiseConv2DBias,
            # which is checked in supported ops
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            weight_shape = [1, h * w, shape[3], shape[0]]
        else:
            # Set weight shape to [H,W,C,B]
            weight_shape = [h, w, shape[3], shape[0]]

        op.rounding_mode = NpuRoundingMode.NATURAL
        identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
        op.forced_input_quantization = identity_quant
        op.forced_output_quantization = identity_quant

        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                quantization=identity_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Input zero point is adjusted after the sum calculation, so we emulate that with a bias
        ofmq, ifmq = op.ofm.quantization, inp.quantization
        bias = -ifmq.zero_point * h * w
        bias_shape = [shape[-1]]
        op.inputs.append(create_const_tensor(op.name + "_bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
        DebugDatabase.add_optimised(op, op)

        # Create intermediate tensor between depthwise conv and mul
        intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
        intermediate.dtype = DataType.int32

        # Multiply sum with 1/num_elements_in_axis to get the mean
        mul_op = Operation(Op.Mul, op.name + "_mul")
        mul_op.add_input_tensor(intermediate)
        mul_op.set_output_tensor(op.ofm)
        mul_op.forced_input_quantization = identity_quant

        # Set dw conv output to the intermediate tensor
        op.set_output_tensor(intermediate)

        # Move activation from original op to mean op
        mul_op.activation = op.activation
        op.activation = None

        # The multiplier is calculated in the same way as in the reference,
        # clamping the shift value at the price of some precision loss.
        num_elements_in_axis = int(h * w)
        output_multiplier, output_shift_vela = quantise_scale(np.double(ifmq.scale_f32) / np.double(ofmq.scale_f32))

        # Convert to reference representation shift value
        output_shift = 31 - output_shift_vela

        # Reference calculation
        # round_down_log2 same as 63 - CountLeadingZeros(num_elements_in_axis)
        shift = round_down_log2(num_elements_in_axis)
        shift = min(shift, 32)
        shift = min(shift, 31 + output_shift)
        output_multiplier = (output_multiplier << shift) // num_elements_in_axis
        output_shift = output_shift - shift

        # Convert to vela representation shift
        output_shift_vela = 31 - output_shift

        # For int32 scaling is not supported so instead multiply with the scale
        # intermediate * scale -> round and shift.
        scalar = create_const_tensor(
            op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [output_multiplier], quantization=identity_quant
        )
        mul_op.add_input_tensor(scalar)
        mul_op.set_ifm_ofm_shapes()

        # Reference using TFL rounding for the multiply
        mul_op.rounding_mode = NpuRoundingMode.TFL

        # Need to use explicit scaling to get the wanted shift
        mul_op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
        DebugDatabase.add_optimised(op, mul_op)
    return op
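
# Rewrite summary:
#
#     Mean(axis) -> DepthwiseConv2DBias(unit weights, bias = -zp_in * h * w)
#                -> intermediate (int32 sum) -> Mul(output_multiplier)
#
# E.g. averaging over an 8x8 window gives num_elements_in_axis = 64 and
# shift = round_down_log2(64) = 6, so output_multiplier is unchanged
# ((m << 6) // 64 == m) while output_shift drops by 6, folding the division
# by 64 into the final shift.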


def convert_ops_to_lut(op, arch, nng):
    if op.type == Op.Exp:
        if op.ifm.dtype == DataType.int8:
            return create_lut_8bit_op(op, math.exp, "exp")
        elif op.ifm.dtype == DataType.int16:
            return create_lut_int16_op(op, math.exp, "exp")
        else:
            # Should already have been caught in tflite supported ops
            assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}"

    return op


def optimise_quantize(op: Operation, arch, nng):

    if op.type == Op.Quantize and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()
        input_values = ifm.values

        # Guard clause - input not const or no values to quantize
        if ifm.ops[0].type != Op.Const or input_values is None:
            return op

        # Singular val in numpy array, convert to indexable array
        if input_values.ndim == 0:
            input_values = np.array([input_values])

        # Requantize int8 to int8 or int16 to int16
        if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:

            # scale needs to use double precision to match TFLite reference kernel
            effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
            effective_multiplier, effective_shift = quantise_scale(effective_scale)

            requantized_vals = []
            for val in input_values.flatten():
                input_val = val - ifm.quantization.zero_point

                ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
                ofm_val += ofm.quantization.zero_point

                clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
                requantized_vals.append(clamped_ofm_value)

            ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
            ofm.values.shape = input_values.shape

        # Case: Float input - quantize to int
        elif ifm.dtype.type == BaseType.Float:

            quantized_vals = []
            for val in input_values:

                # Derive quantized value
                quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
                clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
                quantized_vals.append(clamped_quantized_val)

            # Pass the statically calculated quant val to output tensor
            ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())

        # Unsupported data type
        else:
            return op

        # Make quantize op const and disconnect from parent node

        # Remove reference of the current quant op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this quantize op to const
        op.type = Op.Const

    return op
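
# Worked example: requantising a constant int8 value 10 from scale 0.5 / zero
# point 0 to scale 0.25 / zero point 0 gives effective_scale = 2.0, so the
# stored value becomes 20 (both encode the real value 5.0); the Quantize op
# itself then becomes a Const and is disconnected from its parent.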


def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
    """Static optimisation for SHAPE operator output value known at compile time"""

    # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant

    if op.type == Op.Shape and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()

        if len(ifm.shape) != ofm.shape[0]:
            return op

        # Remove reference of the current shape op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this SHAPE op to const
        op.type = Op.Const

        # Add size calculation to shape output tensors
        ofm.values = np.array(ifm.shape)

    return op
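
# Worked example: a SHAPE op reading a [1, 8, 8, 3] tensor becomes a Const
# producing the values [1, 8, 8, 3]; the guard above only requires that the
# output element count matches the input rank (4 here).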


def fixup_dilation_gt2(op, arch, nng):
    assert op.run_on_npu
    if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
        dilation_w, dilation_h = op.get_kernel_dilation()

        # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
        # kernel
        if dilation_w > 2 or dilation_h > 2:
            kernel_w, kernel_h = op.get_kernel_size()
            kernel_ic = op.weights.shape[-2]
            kernel_oc = op.weights.shape[-1]

            # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
            # of 2. this allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
            # odd = 1, even = 2
            hw_dilation_h = 1 if (dilation_h & 1) else 2
            hw_dilation_w = 1 if (dilation_w & 1) else 2

            scale_dilation_h = dilation_h // hw_dilation_h
            scale_dilation_w = dilation_w // hw_dilation_w

            # create new empty kernel (HWIO format)
            new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
            new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1

            new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
            new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)

            # copy the original kernel values into the new sparse kernel
            for h in range(0, kernel_h):
                for w in range(0, kernel_w):
                    new_h = h * scale_dilation_h
                    new_w = w * scale_dilation_w
                    new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]

            # update the weight tensor with the new dilated kernel
            op.weights.shape = new_kernel_shape
            op.weights.values = new_kernel_values

            # enable(=2) / disable(=1) hardware dilation
            op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1)  # nhwc format
            op.attrs["dilation_h_factor"] = hw_dilation_h
            op.attrs["dilation_w_factor"] = hw_dilation_w

    return op
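
# Worked example: a 3x3 kernel with dilation 4 becomes a sparse 5x5 kernel
# (non-zero rows/columns 0, 2 and 4) with hardware dilation 2, since
# (3 - 1) * (4 // 2) + 1 = 5 and hardware dilation 2 applied to the spacing
# of 2 inside the sparse kernel reproduces the original sampling pattern of 4.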


def fixup_reshape(op, arch, nng):
    def _get_explicit_shape(implicit_shape, total_size):
        # the explicit shape is a copy of the implicit shape but with the special -1 (remaining size) value converted to
        # the appropriate value
        if implicit_shape is None:
            return None

        explicit_shape = list(implicit_shape)
        if -1 in explicit_shape:
            explicit_shape[explicit_shape.index(-1)] = int(total_size / abs(np.prod(implicit_shape)))

        return explicit_shape

    if op.type == Op.Reshape:
        ifm_tensor, _, ofm_tensor = op.get_ifm_ifm2_ofm()
        ifm_size = ifm_tensor.elements()
        ofm_shape = ofm_tensor.shape

        new_shape_tensor_shape = op.inputs[1].values.flatten() if len(op.inputs) > 1 else None
        new_shape_tensor_shape = _get_explicit_shape(new_shape_tensor_shape, ifm_size)

        new_shape_attribute = op.attrs.get("new_shape", None)
        new_shape_attribute = _get_explicit_shape(new_shape_attribute, ifm_size)

        # if present the new shape tensor overrides the new_shape attribute
        if new_shape_tensor_shape is not None:
            # check tensor
            if not np.array_equal(new_shape_tensor_shape, ofm_shape):
                print(
                    f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new shape tensor"
                    f" ({new_shape_tensor_shape}) that does not match output tensor shape {ofm_shape}. Will use output"
                    f" tensor shape."
                )
        elif new_shape_attribute is not None:
            # check attribute
            if not np.array_equal(new_shape_attribute, ofm_shape):
                print(
                    f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new_shape attribute"
                    f" ({new_shape_attribute}) that does not match output tensor shape {ofm_shape}. Will use output"
                    f" tensor shape."
                )
        else:
            print(
                f"Warning: {optype_to_builtintype(op.type)} '{op.name}' does not have a new shape tensor or a new_shape"
                f" attribute. Will use output tensor shape {ofm_shape}."
            )

        # force new shape tensor to output shape
        new_shape_tensor = create_const_tensor(
            op.name + "_new_shape", [len(ofm_shape)], DataType.int32, np.array(ofm_shape, np.int32)
        )
        if len(op.inputs) > 1:
            op.set_input_tensor(new_shape_tensor, 1)
        else:
            op.add_input_tensor(new_shape_tensor)

        # force new_shape attribute to output shape
        op.attrs["new_shape"] = ofm_shape

    return op
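
# Worked example: a new shape of [-1, 4] for a 12-element IFM resolves to
# [3, 4] via _get_explicit_shape([-1, 4], 12); if the resolved shape still
# disagrees with the output tensor shape, the output shape wins and a warning
# is printed.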


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
    # Compile time static optimisations
    optimisation_list = [
        optimise_quantize,
        convert_shape_op_to_constant_tensor,
        fixup_or_check_asymmetric_weights(force_symmetric_int_weights),
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            optimisation_list,
            rewrite_unsupported=False,
        )

    # Pre-processing step
    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            pre_process_list,
            rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [rewrite_split_ops],
            [],
            rewrite_unsupported=False,
        )

    # Bypass or rewrite memory only operators
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [bypass_memory_only_ops],
            rewrite_unsupported=False,
        )

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_ops_to_lut,
        convert_mean_to_depthwise_conv,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_lstm,
        convert_softmax,
        convert_prelu,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        convert_argmax_to_depthwise_conv_and_max_pool,
        fixup_resize,
        fixup_bias_tensors,
        fixup_asymmetric_weights,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
        fixup_dilation_gt2,
        fixup_strided_conv,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            op_rewrite_list,
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after the optimisations have been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    # Make sure that const optimisations on subgraph outputs are handled correctly
    for sg in nng.subgraphs:
        for ofm in sg.output_tensors:
            if ofm.is_const and ofm.ops[0].type_changed:
                # Subgraph output cannot be const - insert a memory copy
                op = ofm.ops[0]
                ofm_clone = ofm.clone()
                ofm_clone.values = ofm.values
                ofm.values = None
                zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
                memcpy = create_add_nop(f"{ofm.name}_copy")
                memcpy.add_input_tensor(ofm_clone)
                memcpy.add_input_tensor(zero)
                memcpy.set_output_tensor(ofm)
                memcpy.set_ifm_ofm_shapes()
                op.set_output_tensor(ofm_clone)
                DebugDatabase.add_optimised(op, memcpy)

    return nng