Johan Alfven9341bf42024-03-05 11:31:49 +01001# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020017# Description:
18# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
19# to do the traversal of the graph.
Raul Farkas10d6b3b2023-01-30 12:58:46 +000020from __future__ import annotations
21
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020022import math
23import uuid
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020024
25import numpy as np
26
27from . import fp_math
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020028from . import rewrite_graph
29from . import scaling
Fredrik Svedberga04f2f72022-07-06 13:42:24 +020030from .data_type import BaseType
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020031from .data_type import DataType
32from .debug_database import DebugDatabase
33from .errors import UnsupportedFeatureError
34from .ethos_u55_regs.ethos_u55_regs import resampling_mode
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020035from .graph_optimiser_util import bypass_memory_only_ops
Patrik Gustavssonc74682c2021-08-17 14:26:38 +020036from .graph_optimiser_util import calc_explicit_padding
Johan Alfven7647b0f2024-04-02 20:56:09 +020037from .graph_optimiser_util import check_splitsliceread_to_consumer_shape
Patrik Gustavssondf995102021-08-23 15:33:59 +020038from .graph_optimiser_util import convert_depthwise_to_conv
Fredrik Svedberg0ac08042023-04-11 22:35:04 +020039from .graph_optimiser_util import create_avg_pool_for_concat
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020040from .graph_optimiser_util import memory_only_ops
Patrik Gustavssonf1580f02021-09-01 12:43:02 +020041from .graph_optimiser_util import move_splitsliceread_to_consumer
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020042from .graph_optimiser_util import needed_total_padding
43from .graph_optimiser_util import set_ifm_ofm_op_shapes
44from .graph_optimiser_util import set_tensor_equivalence
Fredrik Svedberg0ac08042023-04-11 22:35:04 +020045from .lstm import Lstm
Johan Alfvence502732023-04-24 13:35:40 +020046from .lut import convert_to_lut
47from .lut import create_lut_8bit_op
48from .lut import create_lut_int16_op
Johan Alfven8e525ca2023-05-07 13:12:37 +020049from .lut import create_lut_rsqrt_int8_op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020050from .numeric_util import clamp_sigmoid
Johan Alfven56811e62023-03-27 11:33:50 +020051from .numeric_util import full_shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020052from .numeric_util import round_away_zero
Johan Alfven7b3008a2023-04-13 18:54:47 +020053from .numeric_util import round_down_log2
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020054from .operation import create_activation_function
Fredrik Svedberg1a7527c2021-09-13 15:52:16 +020055from .operation import ExplicitScaling
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020056from .operation import NpuBlockType
57from .operation import Op
58from .operation import Operation
59from .operation import Padding
Tim Hall5ff4cd12023-05-16 22:39:14 +010060from .operation import RoundingMode
Alexander Hansson90c34b52023-05-31 15:03:03 +000061from .operation_util import create_add
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +010062from .operation_util import create_add_nop
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020063from .operation_util import create_avgpool_nop
Johan Alfvenc1ad80b2023-03-31 10:19:23 +020064from .operation_util import create_cast_op
Rickard Bolin6986a072022-12-19 12:33:40 +000065from .operation_util import create_depthwise_maxpool
Johan Alfvenc1ad80b2023-03-31 10:19:23 +020066from .operation_util import create_memcpy
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020067from .operation_util import get_pad_values_from_input
Ayaan Masood25f48dd2022-06-29 18:16:04 +010068from .scaling import quantise_scale
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020069from .shape4d import Shape4D
70from .softmax import SoftMax
71from .tensor import check_quantized_tens_scaling_equal
72from .tensor import create_const_tensor
73from .tensor import create_equivalence_id
74from .tensor import QuantizationParameters
75from .tensor import Tensor
76from .tensor import TensorPurpose
77from .tflite_mapping import optype_to_builtintype
Raul Farkas3b64f062023-05-16 17:18:31 +010078from .utils import calc_resize_factor
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020079
80passthrough_nodes = (Op.Identity,)
81
82
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020083def remove_passthrough_tensor(tens, arch, nng):
84 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
85 assert len(tens.ops[0].inputs) == 1
86 tens = tens.ops[0].inputs[0]
87 return tens
88
89
90def rewrite_concat_ops(op, arch):
91 if not op.run_on_npu or not op.type.is_concat_op():
92 return
93
94 axis_4D = 0
95 ofm = op.ofm
96 ofm.ops = []
97 offset = 0
98
99 unfuse_activation_function(op)
100
101 if op.type == Op.Pack:
102 # Pack is also referred to as Stack
103 axis = int(op.attrs["axis"])
104 if axis < 0: # Convert to positive axis
105 axis = len(op.inputs[0].shape) + 1 + axis
106
107 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
108
109 axis_4D = axis + (4 - len(desired_shape))
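# Illustrative example (editor's sketch, not part of the original code): packing three [2, 3]
# inputs along axis=1 gives desired_shape = [2, 1, 3] and axis_4D = 1 + (4 - 3) = 2, i.e. the
# new unit axis sits at index 2 of the equivalent 4D shape [1, 2, 1, 3].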
110
111 for idx, inp in enumerate(op.inputs):
112 op.ifm_shapes[idx] = Shape4D(desired_shape)
113 op.type = Op.PackReshaped
114
115 inputs, axis = op.get_concat_inputs_axis()
116 for idx, inp in enumerate(inputs):
117 if op.type != Op.PackReshaped:
118 op.ifm_shapes[idx] = Shape4D(inp.shape)
119 if axis >= 0:
120 axis_4D = axis + (4 - len(inp.shape))
121 else:
122 axis_4D = axis
123 write_offset = [0, 0, 0, 0]
124 write_offset[axis_4D] = offset
125 concat_end = offset + op.ifm_shapes[idx][axis_4D]
126 create_avg_pool_for_concat(
127 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
128 )
129 offset = concat_end
130 assert ofm.shape[axis] == offset
131
132 return op
133
134
135def rewrite_split_ops(tens, arch, nng):
136
137 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
138 split_op = tens.ops[0]
139
140 # Not supported so leave it and run on CPU
141 if not split_op.run_on_npu:
142 return tens
143
Rickard Bolinbe78a052024-01-31 12:05:11 +0000144 inp, outputs, axis, offset_start, offset_end, strides_tens = split_op.get_split_inputs_axis()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200145
146 tens.ops = []
147 new_op = Operation(Op.SplitSliceRead, split_op.name)
148 new_op.inputs = [inp]
149 ofm_shape_idx = 0
Tim Hall51a8dce2021-12-20 16:49:27 +0000150 if None in (offset_end, offset_start):
151 read_shape = None
152 else:
Rickard Bolinbe78a052024-01-31 12:05:11 +0000153 # The read shape is relative to each start offset
154 # Limit read shape to the size of the IFM - offset is not necessarily limited
155 ifm_dims = split_op.ifm_shapes[0].as_list()
156 read_shape = Shape4D([min(oe, ifm_dim) - os for oe, os, ifm_dim in zip(offset_end, offset_start, ifm_dims)])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200157
158 # For Split the offset cannot be extracted from the tensor so it has to
159 # be calculated from the index of the output tensor
160 if axis is not None:
161 # Get the start and end of the split
162 offset_start = [0] * 4
163 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
164 for idx, out in enumerate(outputs):
165 if axis_4D_list is not None:
166 axis_4D = axis_4D_list[idx]
167 else:
168 split_op.ofm_shapes[idx] = Shape4D(out.shape)
169 if axis >= 0:
170 axis_4D = axis + (4 - len(out.shape))
171 else:
172 axis_4D = axis
173
174 if out == tens:
175 ofm_shape_idx = idx
176 read_shape = split_op.ofm_shapes[idx]
177 break
178
179 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
180
181 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
182 new_op.read_shapes[0] = read_shape
183 new_op.run_on_npu = True
184 new_op.set_output_tensor(tens)
185 new_op.ifm_shapes.append(Shape4D(inp.shape))
186 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
Rickard Bolinbe78a052024-01-31 12:05:11 +0000187 # Set stride multiplier in H/W if a stride tensor is provided
188 s_h, s_w = (strides_tens.values[-3], strides_tens.values[-2]) if strides_tens else (1, 1)
189 new_op.ifm_stride_multiplier[0] = [1, s_h, s_w] # C/H/W
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200190 DebugDatabase.add_optimised(split_op, new_op)
191
192 return tens
193
194
195def remove_SplitSliceRead(op, arch):
196
197 if op.type == Op.SplitSliceRead:
Fredrik Svedberg0ac08042023-04-11 22:35:04 +0200198 # Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
199 # or if an avgpool needs to be inserted
Johan Alfven9341bf42024-03-05 11:31:49 +0100200 # Not possible to move:
Rickard Bolinbe78a052024-01-31 12:05:11 +0000201 # - if ifm stride multiplier is larger than one in any dimension
Johan Alfven9341bf42024-03-05 11:31:49 +0100202 # - if consumer is a Transpose op, since the ifm shape has been reshaped and cannot be changed
203 # - if consumer is elementwise and the ifm needs to be broadcast
Rickard Bolinbe78a052024-01-31 12:05:11 +0000204 if (
205 op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
206 and all(s_mul == 1 for s_mul in op.ifm_stride_multiplier[0])
207 and all(
208 consumer is not None
209 and consumer.run_on_npu
210 and consumer.type not in memory_only_ops
211 and consumer.original_type != Op.Transpose
212 and check_splitsliceread_to_consumer_shape(op, consumer)
213 and not (
214 consumer.type.is_binary_elementwise_op()
215 and Shape4D.from_list(consumer.ofm.shape) != op.ofm_shapes[0]
216 )
217 for consumer in op.ofm.consumer_list
Johan Alfven9341bf42024-03-05 11:31:49 +0100218 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200219 ):
Fredrik Svedberg0ac08042023-04-11 22:35:04 +0200220 # SplitSliceRead can be performed by tensor consumer(s)
221 for cons_op in list(op.ofm.consumer_list):
222 move_splitsliceread_to_consumer(op, cons_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200223 else:
224 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
225 avgpool_op.add_input_tensor(op.ifm)
226 avgpool_op.outputs = [op.ofm]
227 op.ofm.ops.remove(op)
228 op.ofm.ops.append(avgpool_op)
229 avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
230 avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
231 avgpool_op.read_offsets[0] = op.read_offsets[0]
232 avgpool_op.read_shapes[0] = op.read_shapes[0]
Rickard Bolinbe78a052024-01-31 12:05:11 +0000233 if any(s_mul != 1 for s_mul in op.ifm_stride_multiplier[0]):
234 avgpool_op.ifm_stride_multiplier[0] = op.ifm_stride_multiplier[0].copy()
235 avgpool_op.ifm.force_linear_format = True
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200236
237 op.ifm.consumer_list.remove(op)
238 DebugDatabase.add_optimised(op, avgpool_op)
239
240
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200241def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
242 k_w, k_h = kernel.dilated_wh()
243 s_x, s_y = kernel.stride
244 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
245 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
246 if padding_type == Padding.SAME:
247 left_pad = (xpad + 0) // 2
248 right_pad = (xpad + 1) // 2
249 top_pad = (ypad + 0) // 2
250 bottom_pad = (ypad + 1) // 2
251 elif padding_type == Padding.VALID:
252 left_pad = 0
253 right_pad = 0
254 top_pad = 0
255 bottom_pad = 0
256 elif padding_type == Padding.EXPLICIT:
257 # Padding is specified in a PAD operator which has been bypassed.
258 top, left, bottom, right = explicit_padding
259 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
260 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
Rickard Bolin9ae34552022-06-09 13:07:17 +0000261 elif padding_type == Padding.TILE:
262 # The values in the explicit padding only represent the "direction" in which to pad
263 top_pad, left_pad, bottom_pad, right_pad = explicit_padding
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200264 else:
Tim Hall0ab2edc2022-02-23 17:58:02 +0000265 raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200266 padding = (top_pad, left_pad, bottom_pad, right_pad)
267 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
268 return padding, skirt
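# Worked example for calc_padding_and_skirt (editor's sketch, assuming needed_total_padding
# follows the usual TensorFlow SAME-padding formula): a 3x3 kernel with stride 2 over a
# 224-wide IFM needs xpad = max(3 - 2, 0) = 1, giving left_pad = 0 and right_pad = 1; the
# skirt is then (top_pad, left_pad, ypad - top_pad, xpad - left_pad) as returned above.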
269
270
Johan Alfvenc0bb8682023-09-04 17:18:33 +0200271def calc_upscaled_padding_and_skirt(
272 padding_type, kernel_size, stride, input_shape, upscaling_factor_y, upscaling_factor_x
273):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200274 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
275 if padding_type == Padding.SAME:
Johan Alfvenc0bb8682023-09-04 17:18:33 +0200276 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor_y, int(stride[1]), int(kernel_height))
277 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor_x, int(stride[2]), int(kernel_width))
278 right_pad = max(((xpad + 1) // upscaling_factor_x) - 1, 0)
279 bottom_pad = max(((ypad + 1) // upscaling_factor_y) - 1, 0)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200280 left_pad = max(kernel_width - 1 - right_pad, 0)
281 top_pad = max(kernel_height - 1 - bottom_pad, 0)
282 elif padding_type == Padding.VALID:
283 right_pad = max(kernel_width - 2, 0)
284 bottom_pad = max(kernel_height - 2, 0)
285 left_pad = kernel_width - 1
286 top_pad = kernel_height - 1
287 else:
Tim Hall0ab2edc2022-02-23 17:58:02 +0000288 raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200289 padding = (top_pad, left_pad, bottom_pad, right_pad)
290 skirt = padding
291 return padding, skirt
292
293
Raul Farkas66207142023-05-25 11:15:20 +0100294def fixup_conv2d_backprop(op: Operation, arch, nng) -> Operation:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200295 if op.type == Op.Conv2DBackpropInput:
296 # flip the inputs
297 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
298 op.type = Op.Conv2DBackpropInputSwitchedBias
Johan Alfvenc0bb8682023-09-04 17:18:33 +0200299 stride_w = op.kernel.stride.x
300 stride_h = op.kernel.stride.y
301 if stride_w > 1 or stride_h > 1:
302 # Transpose conv2d with upscaling
303 op.ifm_resampling_mode = resampling_mode.TRANSPOSE
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200304
305 # Update strides
306 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
wilisa0179a89042022-11-02 17:18:43 +0000307 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200308
309 return op
310
311
312# Convert the op to an elementwise add
Tim Hall885033b2022-07-21 11:46:03 +0100313def convert_resize_1x1_to_add(op):
314 op.type = Op.Add # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200315 op.name = op.name + "_add"
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200316 # Create an input tensor filled with zeros
wilisa018289d512023-01-12 08:17:23 +0000317 name = op.inputs[1].name + "_add"
318 dtype = op.inputs[0].dtype
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200319 shape = op.ofm_shapes[0].as_list()
wilisa018289d512023-01-12 08:17:23 +0000320 values = np.zeros(shape, dtype.as_numpy_type())
321 quantization = QuantizationParameters(0.0, 255.0)
322 quantization.scale_f32 = 1.0
323 quantization.zero_point = 0
wilisa0116b5e5e2023-02-14 12:03:59 +0000324 op.inputs[1] = op.inputs[0]
325 op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 0)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200326 op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +0000327 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200328
329 return op
330
331
Tim Hall5ff4cd12023-05-16 22:39:14 +0100332# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
Tim Hall885033b2022-07-21 11:46:03 +0100333# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
334# to select the appropriate nearest neighbor value
335def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
336 ifm = op.ifm
337 ofm = op.ofm
338 output_depth = ofm.shape[-1]
339 dw_op_attrs = {
340 "padding": Padding.VALID,
341 "stride_h": 1,
342 "stride_w": 1,
343 "strides": (1, 1, 1, 1),
344 "depth_multiplier": 1,
345 "channel_multiplier": 1,
346 "dilation_h_factor": 1,
347 "dilation_w_factor": 1,
348 "dilation": (1, 1, 1, 1),
349 }
350
Tim Hall5ff4cd12023-05-16 22:39:14 +0100351 # change ResizeNearestNeighbor to Depthwise
Tim Hall885033b2022-07-21 11:46:03 +0100352 op.type = Op.DepthwiseConv2DBias
353 op.attrs.update(dw_op_attrs)
354 op.set_input_tensor(ifm, 0) # ifm tensor index
355 op.activation = None
356
357 # add input resample to resize by x2
358 op.ifm_resampling_mode = resampling_mode.NEAREST
359
360 # don't care about the rounding mode as it is nearest neighbor
361
362 # setup weight tensor
363 weight_quant = QuantizationParameters()
364 weight_quant.scale_f32 = 1.0 # no scaling as only a single non-zero coeff to select the desired value
365 weight_quant.zero_point = 0
366 weight_quant.quant_dim = 0
367 ofm_dtype = ofm.dtype
Tim Hall3b1578e2023-01-13 17:57:25 +0000368 if ofm_dtype.type == BaseType.UnsignedInt:
Tim Hall885033b2022-07-21 11:46:03 +0100369 weight_quant.quant_min = 0
370 weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
371 else:
Tim Hall885033b2022-07-21 11:46:03 +0100372 weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
373 weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
374
375 weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth] # HWIO
376
377 # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
378 # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
379 # below-and-right (i.e. next) to it (D).
380 # 0---1---2
381 # | A | B |
382 # 1---*---+
383 # | C | D |
384 # 2---+---+
385 weight_values = [0] * (upscale_factor * upscale_factor)
386 centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
387 weight_values[centre_coeff] = 1
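# Illustrative (editor's note): for upscale_factor = 2 the kernel is 2x2 and
# centre_coeff = (2 // 2) * 2 + (2 // 2) = 3, i.e. position D in the diagram above;
# for upscale_factor = 4, centre_coeff = 2 * 4 + 2 = 10 (row 2, column 2 of the 4x4 kernel).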
388
389 # add weight tensor, this will discard the size tensor of the resize op
390 op.set_input_tensor(
391 create_const_tensor(
392 "weights",
393 weight_shape,
Tim Hall3b1578e2023-01-13 17:57:25 +0000394 ofm_dtype,
Tim Hall885033b2022-07-21 11:46:03 +0100395 np.array(weight_values).reshape(weight_shape),
Tim Hall885033b2022-07-21 11:46:03 +0100396 quantization=weight_quant,
397 ),
398 1, # inputs tensor weight index
399 )
400
401 # set up the bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
402 # need to append the bias tensor as resize ops only have 2 inputs
403 assert len(op.inputs) == 2
404 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +0200405 fixup_bias_tensors(op, None, None, DataType.int32)
Tim Hall885033b2022-07-21 11:46:03 +0100406
407 # finally update the shape in case we've changed the tensor shapes or connections
408 op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +0000409 DebugDatabase.add_optimised(op, op)
Tim Hall885033b2022-07-21 11:46:03 +0100410
411 return op
412
413
414# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
415 # final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
416 # upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
417def convert_resize_to_upscale_and_average_pool(op):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200418 pre_op = op
419 outputs = op.outputs
Rickard Boline546def2022-01-25 15:45:00 +0000420 dtype = op.ifm.dtype
Tim Hall885033b2022-07-21 11:46:03 +0100421
Rickard Boline546def2022-01-25 15:45:00 +0000422 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
Tim Hall47c76362022-07-18 21:26:47 +0100423 op.attrs["padding"] = Padding.SAME # doesn't really matter as the kernel is 1x1
Tim Hall3c5cfe92022-03-16 16:31:57 +0000424 op.ifm_resampling_mode = resampling_mode.NEAREST
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200425
426 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
Tim Hall47c76362022-07-18 21:26:47 +0100427
428 # Get upscale factor that was calculated in the supported operators check
429 upscale_factor = op.attrs["upscale_factor"]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200430
Rickard Boline546def2022-01-25 15:45:00 +0000431 # Calculate how many times 2x2 upscaling needs to be performed
Tim Hallf9267da2022-04-20 20:19:48 +0100432 # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
433 # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
Rickard Boline546def2022-01-25 15:45:00 +0000434 n = int(np.log2(upscale_factor))
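# Illustrative (editor's note): an x8 resize gives n = 3 stages of x2 nearest-neighbour
# upscaling; the loop below performs the first two and the final stage is handled after the
# loop, combined with the averaging kernel.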
435
Tim Hall885033b2022-07-21 11:46:03 +0100436 # Perform x2 upscaling n-1 times
Rickard Boline546def2022-01-25 15:45:00 +0000437 scaled_op = pre_op
438 for count in range(n - 1):
439 if count > 0:
440 scaled_op = op.clone(f"_{count}")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200441 scaled_op.inputs[0] = pre_op.outputs[0]
442
Tim Hall885033b2022-07-21 11:46:03 +0100443 # Nearest neighbor x2 upscaling
Tim Hall47c76362022-07-18 21:26:47 +0100444 upscaled_shape = upscaled_shape * 2
Rickard Boline546def2022-01-25 15:45:00 +0000445 shape = op.ofm_shapes[0].as_list()
446 shape[1:3] = upscaled_shape
447 out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
448 out_tens.quantization = op.outputs[0].quantization.clone()
449 scaled_op.set_output_tensor(out_tens)
450 pre_op = scaled_op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200451
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200452 scaled_op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +0000453 DebugDatabase.add_optimised(op, scaled_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200454
Tim Hall885033b2022-07-21 11:46:03 +0100455 # Last x2 upscaling
Rickard Boline546def2022-01-25 15:45:00 +0000456 if n > 1:
457 scaled_op = op.clone(f"_{n-1}")
458 scaled_op.inputs[0] = pre_op.outputs[0]
Tim Hall885033b2022-07-21 11:46:03 +0100459
460 if scaled_op.original_type == Op.ResizeBilinear:
461 if scaled_op.attrs["align_corners"]:
462 # no padding
463 scaled_op.attrs["padding"] = Padding.VALID
464 else:
465 # padding to the right and bottom (limits average pool to 8x8 kernel)
466 scaled_op.attrs["padding"] = Padding.EXPLICIT
467 scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
468
469 # kernel size depends on the upscaling factor
470 scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
471 else: # Op.ResizeNearestNeighbor
472 if scaled_op.attrs["align_corners"]:
473 # use depthwise conv to select the correct value
474 scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
475 else:
Johan Alfvéna64616c2022-10-17 12:29:12 +0200476 # Keep 1x1 kernel and average pool, this applies both when
477 # half-pixel-centers is True and False. Calculations are the
478 # same in the reference.
Tim Hall885033b2022-07-21 11:46:03 +0100479 pass
480
Rickard Boline546def2022-01-25 15:45:00 +0000481 scaled_op.outputs = outputs
482 scaled_op.outputs[0].ops = [scaled_op]
483 scaled_op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +0000484 DebugDatabase.add_optimised(op, scaled_op)
Rickard Boline546def2022-01-25 15:45:00 +0000485
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200486 return op
487
488
Raul Farkas66207142023-05-25 11:15:20 +0100489def convert_argmax_to_depthwise_conv_and_max_pool(op: Operation, arch, nng) -> Operation:
Rickard Bolin6986a072022-12-19 12:33:40 +0000490 """
491 Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.
492
493 Example:
494 arr = [4, [00000100,
495 6, = 00000110, # <-- This is the largest value, so we're expecting argmax(arr) = 1
496 5] 00000101]
497
498 Use 16-bit precision and shift all values 7 bits to the left:
499 Shifted_arr = [0000001000000000,
500 0000001100000000,
501 0000001010000000]
502
503 Add "c - index of channel" to each channel:
504 Shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
505 0000001100000001, (+1)
506 0000001010000000] (+0)
507
508 The index is reversed since ArgMax selects the lowest index if the maximum value is found at two indices. The index
509 acts as a tie-breaker between channels with equal values; since we want the smallest channel index to be chosen,
510 we reverse the index before the maxpool and then subtract the index from the number of channels after the maxpool to
511 get the correct index.
512
513 Find the maximum value in the array:
514 val = max(shifted_arr_plus_reverse_idx) = 0000001100000001
515
516 Subtract the value from the number of channels:
517 shifted_arr_plus_idx = (c-1) - val = 2 - 1 = 1
518
519 Extract the 7 lowest bits using a LUT to cut off the 9 most significant bits:
520 idx = LUT(val) = 0000000000000001 = 1
521 """
522
523 if op.type == Op.ArgMax:
524 ifm, ofm = op.inputs[0], op.outputs[0]
525 identity_quant = QuantizationParameters()
526 identity_quant.zero_point = 0
527 identity_quant.scale_f32 = 1.0
Rickard Bolin6986a072022-12-19 12:33:40 +0000528 # Add last dimension to ofm shape
529 ofm.shape += [1]
530 ofm.ops = []
531
532 # Create 1x1 Depthwise convolution with 2**7 weights for each channel to convert precision to 16 bit and shift
533 # all values 7 bits to the left
534 # Set necessary depthwise attributes
535 dw_op_attrs = {
536 "padding": Padding.VALID,
537 "stride_h": 1,
538 "stride_w": 1,
539 "strides": (1, 1, 1, 1),
540 "depth_multiplier": 1,
541 "channel_multiplier": 1,
542 "dilation_h_factor": 1,
543 "dilation_w_factor": 1,
544 "dilation": (1, 1, 1, 1),
545 "explicit_padding": None,
546 }
Johan Alfvenc1ad80b2023-03-31 10:19:23 +0200547 orig_name = op.name
548 op.name = f"{orig_name}_depthwise_conv_SHL_7"
Rickard Bolin6986a072022-12-19 12:33:40 +0000549 op.type = Op.DepthwiseConv2DBias
550 op.attrs.update(dw_op_attrs)
Johan Alfven56811e62023-03-27 11:33:50 +0200551 n, h, w, c = full_shape(4, ifm.shape, 1)
Rickard Bolin6986a072022-12-19 12:33:40 +0000552 shape = [1, 1, 1, c]
553 kernel = np.dstack([2**7] * c)
554 op.inputs = []
555 op.add_input_tensor(ifm)
556 op.add_input_tensor(
557 create_const_tensor(
558 "weights",
559 shape,
560 DataType.uint8,
561 np.array(kernel).reshape(shape),
562 quantization=identity_quant,
563 ),
564 )
565 # Let the bias for each channel be the "reverse" index of the channel it is in, i.e. (c - 1) - channel_idx
566 reverse_idxs = list(reversed(range(c)))
567 bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
568 op.add_input_tensor(bias_tensor)
569
570 intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
571 intermediate_tens.quantization = ifm.quantization
572 op.set_output_tensor(intermediate_tens)
573 op.set_ifm_ofm_shapes()
574 orig_ifm_shape = op.ifm_shapes[0]
575 DebugDatabase.add_optimised(op, op)
576
577 # To extract 7 least significant bits and swap reverse index back to real index using a LUT activation, we set
578 # the base value to c-1 and slope to -128. The 16-bit LUT uses a table of 32-bit values where the top 16 bits
579 # represent the slope and bottom 16 bits the base which are used to interpolate the activation value.
580 slope = (-128 & 0xFFFF) << 16 # Top 16 bits of 32 bit LUT table value
581 base = c - 1 # Bottom 16 bits of the LUT table value
582 lut_tensor = create_const_tensor(
583 "maxpool_LUT_extract_7_LSB",
584 [1, 1, 1, 512],
585 DataType.uint32,
586 [slope + base] * 512,
587 TensorPurpose.LUT,
588 )
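# Illustrative (editor's sketch, assuming the 16-bit LUT interpolates base + slope * frac where
# frac is the position within each 128-wide input segment): with c = 3 and the maximum found in
# channel 1, the maxpool result carries reverse index 1 in its 7 LSBs, so the LUT returns
# 2 + (-128) * (1 / 128) = 1, which is the expected ArgMax result (matching the docstring above).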
589
590 # Split large feature maps into smaller chunks since the Depthwise Maxpool height dimension can overflow due to
591 # flattening the ifm to (H*W)xCx1
592 max_height = 2**16 // orig_ifm_shape.width
593 num_full_height_ops = orig_ifm_shape.height // max_height
594 last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
595 op_heights = [max_height] * num_full_height_ops
596 if last_op_height > 0:
597 op_heights.append(last_op_height)
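# Illustrative (editor's note): for a 300x250 IFM, max_height = 65536 // 250 = 262, so the
# maxpool is split into two chunks of height 262 and 38.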
598
599 # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
600 # maximum allowed height, but that's handled by reading and writing the data in chunks
601 maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
602 maxpool_ofm.quantization = identity_quant
603
604 for op_idx, op_height in enumerate(op_heights):
605 maxpool_op = create_depthwise_maxpool(
606 f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
607 )
608 maxpool_op.outputs = [maxpool_ofm]
609 maxpool_ofm.ops.append(maxpool_op)
610 maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
611 maxpool_op.set_activation_lut(lut_tensor)
612
613 # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
614 maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
615 maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
616 maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
617 maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
618 DebugDatabase.add_optimised(op, maxpool_op)
619
Johan Alfvenc1ad80b2023-03-31 10:19:23 +0200620 # Set final shape
621 maxpool_ofm.set_all_shapes([1, h, w, 1])
622
623 # Convert 16bit to 32bit or 64bit
624 if ofm.dtype == DataType.int64:
625 # If OFM dtype is int64 the result is converted by two cast ops (16bit to 32bit)
626 #
627 # A -> B -> C -> D (OFM)
628 # |0001| |00010000| |0001|0000| |00010000|00000000|
629 # i16 i32 i16 i16 i32 i32
630 # <-------i64------->
631 #
632 # Memcpy is used to copy the content from B to C and from D to OFM
633 # Memcpy will be turned into a nop or a DMA transfer if the memory regions differ.
634 intermediate_32bit = Tensor([1, h, w, 1], DataType.int32, f"{orig_name}_32bit")
635 else:
636 intermediate_32bit = ofm
637
638 op_cast = create_cast_op(f"{orig_name}_cast_to_32bit_1", maxpool_ofm, intermediate_32bit)
639 DebugDatabase.add_optimised(op, op_cast)
640
641 if ofm.dtype == DataType.int64:
642 # Create int16 tensor with double shape to cover the intermediate_32bit result from the first cast
643 intermediate_16bit_2x_size = Tensor([1, h, w, 2], DataType.int16, f"{orig_name}_16bit_2x_size")
644 memcpy_op = create_memcpy(f"{orig_name}_memcpy_1", intermediate_32bit, intermediate_16bit_2x_size)
645 DebugDatabase.add_optimised(op, memcpy_op)
646
647 # Create int32 tensor with double ofm shape to be able to store an "int64" result
648 intermediate_32bit_2x_size = Tensor([1, h, w, 2], DataType.int32, f"{orig_name}_32bit_2x_size")
649
650 op_cast = create_cast_op(
651 f"{orig_name}_cast_to_32bit_2", intermediate_16bit_2x_size, intermediate_32bit_2x_size
652 )
653 DebugDatabase.add_optimised(op, op_cast)
654
655 memcpy_op = create_memcpy(f"{orig_name}_memcpy_2", intermediate_32bit_2x_size, ofm)
656 DebugDatabase.add_optimised(op, memcpy_op)
Rickard Bolin6986a072022-12-19 12:33:40 +0000657
658 return op
659
660
Rickard Bolinfea15162022-07-04 16:19:16 +0000661def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
662 def _compute_interpolation_values(index, input_size, output_size):
663 scale = input_size / output_size
664 scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
665 lower_bound = max(np.floor(scaled_value), 0)
666
667 return scaled_value, lower_bound
668
669 def _compute_kernels(input_height, input_width, output_height, output_width):
670 kernels = []
671 for y in (1, 2):
672 for x in (1, 2):
673 sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
674 sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)
675
676 # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
677 # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
678 # top-to-bottom - same as the depthwise convolution strides across each tile
679 kernel = np.zeros((2, 2))
680 kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
681 kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
682 kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
683 kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
684 kernel *= 16
685 kernels.append(kernel)
686
687 return kernels
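# Illustrative (editor's note): for a plain x2 upscale with half_pixel_centers=True the
# fractional offsets are 0.25 and 0.75, so the four kernels are [[1, 3], [3, 9]] / 16 and its
# mirrored variants, pre-multiplied by 16 here and rescaled by the 1/16 weight scale set below.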
688
689 def _build_convolutions(op, kernels):
690 dw_op_attrs = {
691 "padding": Padding.TILE,
692 "stride_h": 1,
693 "stride_w": 1,
694 "strides": (1, 1, 1, 1),
695 "depth_multiplier": 1,
696 "channel_multiplier": 1,
697 "dilation_h_factor": 1,
698 "dilation_w_factor": 1,
699 "dilation": (1, 1, 1, 1),
700 }
701 ifm = op.ifm
702 ofm = op.ofm
703 ofm.ops = []
704 elem_size = 2 if ofm.dtype == DataType.int16 else 1
705
706 n, h, w, c = ifm.shape
707 _, _, ow, _ = ofm.shape
708
709 intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
710 intermediate_tens.quantization = op.outputs[0].quantization.clone()
711 avgpool_op = op
712 avgpool_op.name = "rb_init_avgpool"
713 avgpool_op.type = Op.AvgPool
714 avgpool_op.attrs["padding"] = Padding.VALID
715 avgpool_op.attrs["stride_w"] = 1
716 avgpool_op.attrs["stride_h"] = 1
717 avgpool_op.attrs["filter_width"] = 1
718 avgpool_op.attrs["filter_height"] = 1
719 avgpool_op.attrs["strides"] = [1, 1, 1, 1]
720 avgpool_op.attrs["ksize"] = [1, 1, 1, 1]
721
722 avgpool_op.add_input_tensor(ifm)
723 avgpool_op.set_output_tensor(intermediate_tens)
724 avgpool_op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +0000725 DebugDatabase.add_optimised(op, op)
Rickard Bolinfea15162022-07-04 16:19:16 +0000726
727 dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
728 dw_conv._original_type = Op.ResizeBilinear
729 dw_conv.write_shape = Shape4D(n, h, w, c)
730 dw_conv.write_offset = Shape4D(0, 0, 0, 0)
731
Tim Hall5ff4cd12023-05-16 22:39:14 +0100732 # Resize bilinear requires rounding away from zero
733 dw_conv.rounding_mode = RoundingMode.AwayZero
Rickard Bolinfea15162022-07-04 16:19:16 +0000734
735 # Double height and width stride to write the output of each of the four depthwise convolutions below
736 # interleaved with each other when combined with OFM tile base offsets.
737 dw_conv.ofm_stride_multiplier = [1, 2, 2] # C/H/W
738
739 # Choose tile padding direction - pad by 1 with edge values in two directions.
740 # For example, TL (top left) will pad top and left in H/W-plane in all channels.
741 directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]] # TL, TR, BL, BR
742 for i in (0, 1):
743 for j in (0, 1):
744 index = i * 2 + j
745 dw_conv.name = f"depthwise_conv_{index}"
746 dw_op_attrs["explicit_padding"] = directions[index]
747 dw_conv.attrs.update(dw_op_attrs)
748
749 # This will offset the start of the write by modifying the Tile 0 base address
750 dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size
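# Illustrative (editor's note): for an int8 OFM of width ow and depth c, the four depthwise
# convolutions write to byte offsets 0, c, ow*c and (ow + 1)*c, i.e. the four interleaved
# output pixels of each 2x2 output block.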
751
752 ofm.ops.append(dw_conv)
753 dw_conv.outputs = [ofm]
754
755 kernel = kernels[index]
756 shape = [2, 2, 1, c]
757 kernel = np.dstack([kernel] * c)
758
759 quant = QuantizationParameters()
760 quant.zero_point = 0
761 quant.scale_f32 = 1.0 / 16
762
763 dw_conv.inputs = []
764 dw_conv.add_input_tensor(intermediate_tens)
765 dw_conv.add_input_tensor(
766 create_const_tensor(
767 "weights",
768 shape,
769 intermediate_tens.dtype,
770 np.array(kernel).reshape(shape),
Rickard Bolinfea15162022-07-04 16:19:16 +0000771 quantization=quant,
772 ),
773 )
774
775 # set up the bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
776 # need to append the bias tensor as resize ops only have 2 inputs
777 assert len(dw_conv.inputs) == 2
778 dw_conv.inputs.append(None)
Rickard Bolin017b4cc2022-09-23 10:16:48 +0000779 fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)
Rickard Bolinfea15162022-07-04 16:19:16 +0000780
781 dw_conv.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +0000782 DebugDatabase.add_optimised(op, dw_conv)
783
Rickard Bolinfea15162022-07-04 16:19:16 +0000784 dw_conv = dw_conv.clone(f"_{index}")
785 return op
786
787 _, input_height, input_width, _ = op.ifm.shape
788 _, output_height, output_width, _ = op.ofm.shape
789
790 kernels = _compute_kernels(input_height, input_width, output_height, output_width)
791 op = _build_convolutions(op, kernels)
792
793 return op
794
795
Raul Farkas66207142023-05-25 11:15:20 +0100796def fixup_resize(op: Operation, arch, nng) -> Operation:
797 """Fixup resize ops to increase support for ResizeNearestNeighbor cases."""
Tim Hall885033b2022-07-21 11:46:03 +0100798 if op.type.is_resize_op() and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200799 if op.ifm_shapes[0] == op.ofm_shapes[0]:
Tim Hall885033b2022-07-21 11:46:03 +0100800 # Bypass the resize op which is essentially a NOP
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200801 op.inputs = op.inputs[:1]
802 op.type = Op.Identity
803 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
Tim Hall885033b2022-07-21 11:46:03 +0100804 convert_resize_1x1_to_add(op)
Rickard Bolinfea15162022-07-04 16:19:16 +0000805 elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
806 convert_resizebilinear_to_depthwise_convolutions(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200807 else:
Tim Hall885033b2022-07-21 11:46:03 +0100808 convert_resize_to_upscale_and_average_pool(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200809
810 return op
811
812
813def convert_nop_split_to_identity(op, arch, nng):
814 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
815 # the list comprehension should return a list with a single tensor
816 # if it shouldn't, remove_passthrough_tensor will fail appropriately
817 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
818 op.type = Op.Identity
819 return op
820
821
Raul Farkas66207142023-05-25 11:15:20 +0100822def rewrite_fully_connected_input(op: Operation, arch, nng) -> Operation:
823 """Rewrite FullyConnected shape as 2D to allow it to run on NPU."""
Fredrik Svedberg0ac08042023-04-11 22:35:04 +0200824 # If the operation already has a read shape, do not modify
825 # the ifm shape, since that will already be correct
826 if op.type == Op.FullyConnected and not op.read_shapes[0]:
Ayaan Masooda2ec5aa2022-04-21 14:28:03 +0100827 new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
828 assert new_shape is not None, "Tensor can not be reshaped to 2D"
829 op.ifm_shapes[0] = new_shape
Johan Alfvén65835e02022-10-13 10:49:30 +0200830
831 if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
832 # If IFM is batching then also make sure OFM is batching
833 h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
834 op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])
835
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200836 return op
837
838
Raul Farkas66207142023-05-25 11:15:20 +0100839def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
840 """Convert batched FullyConnected op shape to allow for support on NPU."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200841 if op.type == Op.FullyConnected:
842 # Check if the first dimension indicates batching
843 if op.ifm_shapes[0].batch > 1:
Johan Alfvenf4937002024-04-20 08:04:27 +0200844 batching_split = {4: (2, 2), 6: (2, 3), 8: (2, 4), 9: (3, 3), 12: (3, 4), 16: (4, 4)}
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200845 n = op.ifm_shapes[0].batch
846 h, w = batching_split.get(n, (1, n))
847 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
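# Illustrative (editor's note): an IFM batch of 8 with depth D becomes a 1x2x4xD feature map
# (batching_split[8] == (2, 4)); a batch of 5 has no entry and falls back to 1x1x5xD.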
848
849 # Reshape Weights to be 4D. IO becomes HWIO
850 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +0100851 weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
852 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200853
854 n = op.ofm_shapes[0].batch
855 h, w = batching_split.get(n, (1, n))
856 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
Johan Alfvenf4937002024-04-20 08:04:27 +0200857 if h == 1 and w > 4:
858 # If the batch size cannot be found in the split set, the weights are going to be
859 # read from memory several times. Convert the op to conv2d since this
860 # enables weight buffering.
861 op.type = Op.Conv2DBias
862 op.attrs["padding"] = Padding.SAME
863 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200864 return op
865
866
867def unfuse_activation_function(op):
868 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
869 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
870 op.activation = None
871 out_tens = op.outputs[0]
872 intermediate_tens = out_tens.clone("_act_intermediate")
873 act_op.set_output_tensor(out_tens)
874 act_op.add_input_tensor(intermediate_tens)
875 op.set_output_tensor(intermediate_tens)
876 act_op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +0000877 DebugDatabase.add_optimised(op, act_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200878
879
880def rewrite_stridedslice_output(op, arch, nng):
881 if not op.run_on_npu or op.type != Op.StridedSlice:
882 return op
883
884 new_axis_mask = op.attrs["new_axis_mask"]
885 shrink_axis_mask = op.attrs["shrink_axis_mask"]
886
887 if shrink_axis_mask == 0 and new_axis_mask == 0:
888 return op
889
890 axis_4D = [0] * len(op.outputs)
891 for idx, out_tens in enumerate(op.outputs):
892 output_shape = list(out_tens.shape)
893
894 if shrink_axis_mask != 0:
895 n = 0
896 axis = 0
897 while shrink_axis_mask:
898 prev_mask = shrink_axis_mask
899 n += 1
900 shrink_axis_mask &= shrink_axis_mask - 1
901 axis = int(math.log2(prev_mask - shrink_axis_mask))
902 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
903
904 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
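# Illustrative (editor's note): with shrink_axis_mask = 0b0101 the loop above runs twice;
# prev_mask - shrink_axis_mask isolates the lowest set bit each time (axis 0, then axis 2),
# and a unit dimension is re-inserted at each of those axes, so e.g. a [3, 5] output shrunk
# from a [1, 3, 1, 5] input is tracked internally as the 4D shape [1, 3, 1, 5] again.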
905 op.attrs["shrink_axis_mask"] = 0
906 if axis >= 0:
907 axis_4D[idx] = axis + (4 - len(output_shape))
908 else:
909 axis_4D[idx] = axis
910 op.ofm_shapes[idx] = Shape4D(output_shape)
911
912 elif new_axis_mask != 0:
913 n = 0
914 axis = 0
915 while new_axis_mask:
916 prev_mask = new_axis_mask
917 n += 1
918 new_axis_mask &= new_axis_mask - 1
919 axis = int(math.log2(prev_mask - new_axis_mask))
920 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
921 new_axis_mask >>= 1
922
923 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
924 op.attrs["new_axis_mask"] = 0
925 if axis >= 0:
926 axis_4D[idx] = axis + (4 - len(output_shape))
927 else:
928 axis_4D[idx] = axis
929 op.ofm_shapes[idx] = Shape4D(output_shape)
930
931 op.attrs["split_axis_4D"] = axis_4D
932 return op
933
934
935def rewrite_unpack_output(op, arch, nng):
936 tens = op.outputs[0]
937 if op.run_on_npu and op.type == Op.Unpack:
938 # Unpack is also referred to as Unstack
939 axis = int(op.attrs["axis"])
940 if axis < 0: # Convert to positive axis
941 axis = len(op.inputs[0].shape) + 1 + axis
942 op.type = Op.UnpackReshaped
943 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
944
945 axis_4D = axis + (4 - len(desired_output_shape))
946 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
947
948 for idx, out_tens in enumerate(op.outputs):
949 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
950 return op
951
952
953def add_padding_fields(op, arch, nng):
954 if op.run_on_npu:
955 if "padding" in op.attrs:
956 input_shape = op.ifm_shapes[0]
957 output_shape = op.ofm_shapes[0]
958 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
959 kernel_size = op.inputs[1].shape[:2]
960 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
961 kernel_size = op.attrs["ksize"][1:3]
962 else:
963 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
964
Johan Alfvenc0bb8682023-09-04 17:18:33 +0200965 if op.type == Op.Conv2DBackpropInputSwitchedBias and op.ifm_resampling_mode == resampling_mode.TRANSPOSE:
966 # Transpose with upscale
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200967 padding, skirt = calc_upscaled_padding_and_skirt(
Johan Alfvenc0bb8682023-09-04 17:18:33 +0200968 op.attrs["padding"],
969 kernel_size,
970 op.attrs["strides"],
971 input_shape,
972 output_shape.height // input_shape.height,
973 output_shape.width // input_shape.width,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200974 )
975 else:
976 padding, skirt = calc_padding_and_skirt(
Jonas Ohlssond8575072022-03-30 10:30:25 +0200977 op.attrs["padding"],
978 op.kernel,
979 input_shape,
980 op.attrs.get("explicit_padding"),
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200981 )
982
983 op.attrs["explicit_padding"] = padding
984 op.attrs["skirt"] = skirt
985
986 return op
987
988
Raul Farkas66207142023-05-25 11:15:20 +0100989def reorder_depthwise_weights(op: Operation, arch, nng) -> Operation:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200990 if op.type.is_depthwise_conv2d_op():
991 weight_tensor = op.inputs[1]
Alexander Hansson90c34b52023-05-31 15:03:03 +0000992 if not weight_tensor.weight_transpose_depthwise:
993 weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
994 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
995 weight_tensor.weight_transpose_depthwise = True
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200996
997 return op
998
999
Raul Farkas3e7157b2023-05-09 09:09:17 +01001000def convert_avg_pool_to_conv2d(op: Operation, arch, nng) -> Operation:
1001 """Convert strided Average Pools with stride >= 4 to Conv2D."""
1002 if op.type != Op.AvgPool:
1003 return op
1004
1005 stride_x, stride_y = op.get_kernel_stride()
1006 # For strides <= 3 no optimization is needed
1007 if stride_x <= 3:
1008 return op
1009 h, w = op.attrs["filter_height"], op.attrs["filter_width"]
1010 inputs = op.inputs[0]
1011 shape = inputs.shape
1012
1013 # Set necessary conv2d attributes
1014 op.attrs.update(
1015 {
1016 "stride_h": stride_y,
1017 "stride_w": stride_x,
1018 "dilation_h_factor": 1,
1019 "dilation_w_factor": 1,
1020 "strides": (1, stride_y, stride_x, 1),
1021 "dilation": (1, 1, 1, 1),
1022 }
1023 )
1024
1025 # Change op type
1026 op.type = Op.Conv2DBias
1027 op.name += "_conv2d"
1028
1029 op.rounding_mode = RoundingMode.AwayZero
1030 shape = [h, w, 1, op.ofm.shape[-1]]
1031 weights = np.full(shape, 1)
1032 quant = QuantizationParameters(scale_f32=1 / (h * w), zero_point=0)
1033 # Add unit weight tensor
1034 op.add_input_tensor(
1035 create_const_tensor(
1036 "weights",
1037 shape,
1038 inputs.dtype,
1039 weights,
1040 quantization=quant,
1041 ),
1042 )
1043 op.weights.values = np.reshape(op.inputs[1].values, shape)
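# Illustrative (editor's note): a 5x5 AvgPool with stride 5 becomes a Conv2D with a 5x5
# all-ones kernel and a weight scale of 1/25, so the convolution sum multiplied by the scale
# reproduces the average.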
1044
1045 # Set IFM/OFM shapes after changing op type
1046 op.set_ifm_ofm_shapes()
1047 return op
1048
1049
1050def fixup_strided_conv(op: Operation, arch, nng):
Raul Farkas72c6a242023-03-16 16:38:05 +00001051 """Optimize or fixup strided Conv2DBias
1052 Optimization:
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001053 Reduce, when possible, the Conv2DBias stride from N with 1 < N < 4 to 1
1054 by re-shaping both IFM and filter.
Raul Farkas72c6a242023-03-16 16:38:05 +00001055
1056 Fixup:
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001057 Introduce software support for Conv2DBias with stride_width > 4 by
1058 reducing it to 1, 2 or 3 (HW supported strides) when possible by
1059 re-shaping both IFM and filter.
Raul Farkas72c6a242023-03-16 16:38:05 +00001060 """
Raul Farkas090f18a2023-01-24 16:29:06 +00001061 if op.type != Op.Conv2DBias:
Louis Verhaard43d27582022-03-17 14:06:00 +01001062 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001063 stride_x, stride_y = op.get_kernel_stride()
Louis Verhaard43d27582022-03-17 14:06:00 +01001064 weight_tensor = op.weights
1065 ifm_shape = op.ifm_shapes[0]
Raul Farkas69782af2023-05-09 10:39:52 +01001066
1067 # Do not optimize if op is not the first in the network and stride is
1068 # supported by the hardware
1069 if op.op_index != 0 and stride_x < 4:
1070 return op
1071
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001072 resize_factor, final_stride = calc_resize_factor(ifm_shape.width, stride_x)
1073
1074 def calc_filter_padding(
1075 ifm_padding_type: Padding | None,
1076 ifm_current_padding_x: int,
1077 post_op_stride: int,
1078 opt_resize_factor: int,
1079 filter_width: int,
Raul Farkas3b64f062023-05-16 17:18:31 +01001080 ifm_width: int,
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001081 ) -> tuple[int, int, int, int]:
1082 """Calculate zero padding to be added to the filter.
1083
1084 Parameters
1085 ----------
1086 ifm_padding_type : Padding or None
1087 The padding type that is applied to the IFM.
1088 ifm_current_padding_x : int
1089 Padding amount that is added to the IFM before optimization.
1090 post_op_stride : int
1091 The final stride once optimization is performed.
1092 opt_resize_factor : int
1093 The factor by which the stride will be reduced.
1094 E.g. opt_resize_factor = 2 on a stride of 4 will produce
1095 a stride of 2 after the optimization
1096 filter_width : int
1097 Width of the filter before optimization.
Raul Farkas3b64f062023-05-16 17:18:31 +01001098 ifm_width : int
1099 Width of the IFM before optimization
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001100
1101 Returns
1102 -------
1103 padding : tuple[int, int, int, int]
1104 A tuple with the amount of padding on each side (top, left, bottom, right)
1105 """
1106 padding_size = 0
1107 padding = (0, 0, 0, 0)
1108 if ifm_padding_type and ifm_padding_type != Padding.VALID:
Raul Farkas3b64f062023-05-16 17:18:31 +01001109 # Compute padding size for the filter that guarantees that HW padding added to IFM matches
1110 # before and after the optimization is performed
1111 expected_filter_size = 0
1112 pre_opt_stride = post_op_stride * opt_resize_factor
1113 post_opt_ifm_width = ifm_width // opt_resize_factor
1114 # Compute the total expected filter size post optimization that ensures that the same HW padding
1115 # is added to IFM.
1116 # There are two ways of calculating required filter size depending on whether IFM width is divisible
1117 # by stride width or not. These approaches match the cases used to calculate HW padding in
1118 # needed_total_padding method.
1119 if ifm_width % pre_opt_stride == 0:
1120 expected_filter_size = ifm_current_padding_x + post_op_stride
1121 else:
1122 expected_filter_size = ifm_current_padding_x + (post_opt_ifm_width % post_op_stride)
1123 # Compute padding size from expected filter size
1124 padding_size = expected_filter_size * opt_resize_factor - filter_width
1125
1126 if ifm_current_padding_x == 0:
1127 # If no HW padding is added to IFM, divide filter padding between left and right following
1128 # the same strategy as the reference.
1129 padding_left = padding_size // 2
1130 else:
1131 # If HW padding is added to IFM, split padding for the filter so that left padding and right padding
1132 # are proportional to left and right HW padding.
1133 left_hw_padding = ifm_current_padding_x // 2
1134 # Compute filter padding
1135 padding_left = padding_size // ifm_current_padding_x * left_hw_padding
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001136 padding = (0, padding_left, 0, padding_size - padding_left)
1137
1138 # Check if filter width is divisible by the stride width (required for optimization)
Raul Farkas3b64f062023-05-16 17:18:31 +01001139 # If filter width is not divisible by stride width and no HW padding is added to IFM, compute
1140 # filter padding required for the filter width to be divisible by the stride width and apply it as right
1141 # padding.
1142 if filter_width % opt_resize_factor != 0 and (padding_size == 0 or ifm_current_padding_x == 0):
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001143 padding_size = opt_resize_factor - (filter_width % opt_resize_factor)
1144 # Add padding zeros to the right
1145 padding = (0, 0, 0, padding_size)
1146
1147 return padding
1148
1149 # Compute the depth of the IFM once the strided Conv2D is optimised
1150 post_opt_ifm_depth = ifm_shape.depth * resize_factor
1151
1152 if stride_x > 1 and (post_opt_ifm_depth <= 8 or stride_x > 3) and resize_factor != 1 and weight_tensor is not None:
1153 k_w, _ = op.get_kernel_size()
1154 weight_shape = weight_tensor.shape
1155
1156 padding_type = op.attrs.get("padding", None)
1157 if padding_type in (None, Padding.EXPLICIT, Padding.TILE):
Louis Verhaard43d27582022-03-17 14:06:00 +01001158 return op
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001159 # Compute current padding as if IFM padding is SAME
1160 curr_padding_x = needed_total_padding(ifm_shape.width, stride_x, k_w)
1161 # Compute the padding needed on the filter for the optimisation
1162 _, left_filter_padding, _, right_filter_padding = calc_filter_padding(
Raul Farkas3b64f062023-05-16 17:18:31 +01001163 padding_type, curr_padding_x, final_stride, resize_factor, k_w, ifm_shape.width
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001164 )
1165 total_horizontal_padding = left_filter_padding + right_filter_padding
1166 # If IFM padding is enabled, check if pre-opt and post-opt padding is
1167 # the same while taking into consideration the extra filter padding.
1168 if padding_type == Padding.SAME:
1169 optimised_padding_x = needed_total_padding(
1170 ifm_shape.width // resize_factor, final_stride, (k_w + 1 + total_horizontal_padding) // resize_factor
1171 )
1172 if curr_padding_x != optimised_padding_x:
1173 # Horizontal padding would become different after optimisation; this would not work
1174 return op
1175
1176 # Resize IFM
Raul Farkas090f18a2023-01-24 16:29:06 +00001177 op.ifm_shapes[0] = Shape4D(
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001178 [ifm_shape.batch, ifm_shape.height, ifm_shape.width // resize_factor, ifm_shape.depth * resize_factor]
Raul Farkas090f18a2023-01-24 16:29:06 +00001179 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001180
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001181 # Compute list of 0 padding for each dimensions of the filter
1182 filter_dimension_padding = [(0, 0) for _ in weight_tensor.shape]
1183 # Update padding for filter width with computed padding
1184 filter_dimension_padding[1] = (left_filter_padding, right_filter_padding)
1185 # Add padding to the filter
1186 zero_point = weight_tensor.quantization.zero_point
1187 padding_constant = zero_point if np.isscalar(zero_point) else 0
1188 padded_filter_tensor = np.pad(weight_tensor.values, filter_dimension_padding, constant_values=padding_constant)
1189 weight_shape[1] = padded_filter_tensor.shape[1]
1190 weight_tensor.values = padded_filter_tensor
Raul Farkas090f18a2023-01-24 16:29:06 +00001191 # Change weight shape based on stride_x
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001192 weight_shape[1] //= resize_factor
1193 weight_shape[2] *= resize_factor
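# Illustrative (editor's note): with resize_factor = 2 a filter of width 4 over C input channels
# becomes a filter of width 2 over 2*C channels, matching the IFM reshape above where the width
# is halved and the depth doubled.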
Raul Farkas090f18a2023-01-24 16:29:06 +00001194
James Peet7519d502021-07-19 16:47:58 +01001195 weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001196 weight_tensor.set_all_shapes(weight_shape)
1197 # If multiple copies of the weights are used, we could avoid
1198 # them having the same address by changing the value_id
1199 weight_tensor.value_id = uuid.uuid4()
1200
1201 # Strides
Raul Farkas10d6b3b2023-01-30 12:58:46 +00001202 stride_x = final_stride
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001203 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
1204
Johan Alfvenafb56ae2023-10-27 13:08:21 +02001205 ofm_shape = op.ofm_shapes[0]
1206 if ofm_shape.height == 1 or ofm_shape.width == 1:
1207 # If height or width is 1 no stride is done in y or x direction and stride value can be set to 1
1208 # Before forcing kernel stride to 1 make sure to calculate the correct padding since it is
1209 # based on the original kernel stride
1210 padding, _ = calc_padding_and_skirt(
1211 op.attrs["padding"],
1212 op.kernel,
1213 ifm_shape,
1214 op.attrs.get("explicit_padding"),
1215 )
1216 # Use explicit padding so it is not recalculated later with the wrong kernel stride
1217 op.attrs["padding"] = Padding.EXPLICIT
1218 op.attrs["explicit_padding"] = padding
1219
1220 stride_y = 1 if ofm_shape.height == 1 else stride_y
1221 stride_x = 1 if ofm_shape.width == 1 else stride_x
1222
1223 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
1224
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001225 return op
1226
1227
Raul Farkas66207142023-05-25 11:15:20 +01001228def convert_conv_to_fc(op: Operation, arch, nng) -> Operation:
1229 """Convert 1x1 Conv2D that behave like FullyConnected to FullyConnected, since they don't need any weight
1230 buffering.
1231 """
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001232 # Conv 1x1 can be equivalent to Fully Connected.
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001233 # (Weights don't need to be reloaded for convs when IFM H and W are 1)
1234 if op.type == Op.Conv2DBias:
1235 h = op.ifm_shapes[0].height
1236 w = op.ifm_shapes[0].width
1237 kh, kw, _, _ = op.inputs[1].shape
1238 if h == 1 and w == 1 and kh == 1 and kw == 1:
1239 # Overwrite this op as a Fully Connected Op
1240 op.name += "_fc"
1241 op.type = Op.FullyConnected
1242 op.attrs = {
1243 "weights_format": 0,
1244 }
1245 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
1246 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +01001247 weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
1248 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001249
1250 DebugDatabase.add_optimised(op, op)
1251 return op
1252
1253
Raul Farkas66207142023-05-25 11:15:20 +01001254def fixup_relus_with_differing_ifm_ofm_scaling(op: Operation, arch, nng) -> Operation:
1255 """Fixup Relu with different IFM and OFM to allow fusing by adding its own primary op."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001256 if op.run_on_npu and op.type.is_relu_op():
1257 ifm = op.inputs[0]
1258 ofm = op.outputs[0]
1259 # Relu with differing IFM and OFM scaling cannot be fused with another primary op
1260 # and requires its own to be inserted
1261 if not check_quantized_tens_scaling_equal(ifm, ofm):
1262 # Override this op with its own primary op (avgpool)
1263 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
1264 # And fuse the original activation function to it
1265 relu_fused_op.activation = create_activation_function(op.type)
Fredrik Svedberg1a7527c2021-09-13 15:52:16 +02001266 # Add explicit rescaling
1267 rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
1268 multiplier, shift = scaling.quantise_scale(rescale)
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001269 relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001270 # Tidy up and assign the ifm and ofm to the new op
1271 ifm.consumer_list.remove(op)
1272
1273 relu_fused_op.add_input_tensor(ifm)
1274 relu_fused_op.set_output_tensor(ofm)
1275 relu_fused_op.set_ifm_ofm_shapes()
1276 op = relu_fused_op
1277 return op
1278
1279
Raul Farkas66207142023-05-25 11:15:20 +01001280def convert_lstm(op: Operation, arch, nng) -> Operation:
1281 """Convert LSTM op into its basic opearations to allow for support on NPU."""
Fredrik Svedberg0ac08042023-04-11 22:35:04 +02001282 if op.type == Op.UnidirectionalSequenceLstm:
1283 lstm = Lstm(op)
1284 op = lstm.get_graph()
1285 return op
1286
1287
Raul Farkas66207142023-05-25 11:15:20 +01001288def convert_softmax(op: Operation, arch, nng) -> Operation:
1289 """Convert Softmax op into its basic operations to allow for support on NPU."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001290 if op.type == Op.Softmax and op.run_on_npu:
1291 softmax = SoftMax(op)
1292 op = softmax.get_graph()
1293 return op
1294
1295
Raul Farkas66207142023-05-25 11:15:20 +01001296def convert_prelu(op: Operation, arch, nng) -> Operation:
1297 """Convert PReLU op to other ops based on alpha values to allow for support on NPU."""
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001298 if op.type == Op.Prelu:
1299 ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
1300 if None in (ifm, alpha, ofm):
1301 return op
1302
Fredrik Svedberg66591652022-08-29 10:51:27 +02001303 if alpha.values is not None:
1304 # If alpha is const, check for possible optimisations
1305 alpha_zp = alpha.quantization.zero_point
1306 alpha_scale = alpha.quantization.scale_f32
1307 # If all alpha values are the same the PReLU can be converted to LeakyRelu
Rickard Bolin5fdcf172022-12-19 12:56:17 +00001308 alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
1309 alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
Fredrik Svedberg66591652022-08-29 10:51:27 +02001310 if alpha_min == alpha_max:
1311 # or even a Relu
1312 if alpha_min == 0:
1313 new_op = Op.Relu
1314 else:
1315 new_op = Op.LeakyRelu
1316 op.attrs["alpha"] = alpha_min
1317 # setup alpha_scaling for bit exact result
1318 ifm_scale = ifm.quantization.scale_f32
1319 ofm_scale = ofm.quantization.scale_f32
1320 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
1321 op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
1322 # Change op type
1323 op.type = new_op
1324 op.name = op.name.replace("Prelu", new_op.name)
1325 del op.inputs[1] # Remove alpha tensor
1326 return op
1327 elif alpha_max < 1:
1328 # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
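# This holds because for alpha <= 1, max(alpha * x, x) is x when x >= 0 and alpha * x when
# x < 0, which is exactly PReLU(x)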
1329 # Multiply with alpha tensor
1330 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
1331 mul_alpha.add_input_tensor(ifm)
1332 mul_alpha.add_input_tensor(alpha)
1333 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1334 mul_alpha.set_output_tensor(fm_alpha)
1335 mul_alpha.set_ifm_ofm_shapes()
1336 DebugDatabase.add_optimised(op, mul_alpha)
1337 if check_quantized_tens_scaling_equal(ifm, ofm):
1338 # No scaling is needed
1339 fm_id = ifm
1340 else:
1341 # Add multiplication with identity
1342 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1343 mul_identity.add_input_tensor(ifm)
1344 # Create const tensor containing identity as scalar
1345 quantization = ifm.quantization.clone()
1346 quantization.scale_f32 = np.float32(1)
1347 quantization.zero_point = 0
1348 one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
1349 mul_identity.add_input_tensor(one)
1350 # Make sure that fm_id is allocated to a different address than fm_alpha
1351 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1352 mul_identity.set_output_tensor(fm_id)
1353 mul_identity.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +00001354 DebugDatabase.add_optimised(op, mul_identity)
Fredrik Svedberg66591652022-08-29 10:51:27 +02001355
1356 # Combine scaled and alpha multiplied values
1357 max_op = Operation(Op.Maximum, op.name + "_max")
1358 max_op.add_input_tensor(fm_alpha)
1359 max_op.add_input_tensor(fm_id)
1360 max_op.set_output_tensor(ofm)
1361 max_op.set_ifm_ofm_shapes()
1362
1363 DebugDatabase.add_optimised(op, max_op)
1364 ifm.consumer_list.remove(op)
1365 return max_op
1366
1367 # Catch all PReLU conversion for the cases that could not be optimised above
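# The decomposition used below is PReLU(x) = Relu(x) + alpha * min(x, 0): the Minimum/Mul pair
# produces the scaled negative part, the Relu the positive part, and the final Add combines
# them without any extra scaling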
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001368 no_scale_quant = ifm.quantization.clone()
1369 no_scale_quant.scale_f32 = None
1370 no_scale_quant.zero_point = 0
Fredrik Svedberg66591652022-08-29 10:51:27 +02001371 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001372
1373 # Select values < 0
1374 min_op = Operation(Op.Minimum, op.name + "_min")
1375 min_op.add_input_tensor(ifm)
1376 min_op.add_input_tensor(zero)
1377 fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
1378 min_op.set_output_tensor(fm_negative)
1379 min_op.set_ifm_ofm_shapes()
1380 DebugDatabase.add_optimised(op, min_op)
1381
1382 # and multiply with alpha tensor
1383 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
1384 mul_alpha.add_input_tensor(fm_negative)
1385 mul_alpha.add_input_tensor(alpha)
1386 fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
1387 mul_alpha.set_output_tensor(fm_alpha)
1388 mul_alpha.set_ifm_ofm_shapes()
1389 DebugDatabase.add_optimised(op, mul_alpha)
1390
1391 # Select (and scale) values > 0
1392 relu_op = Operation(Op.Relu, op.name + "_relu")
1393 relu_op.add_input_tensor(ifm)
1394 fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1395 relu_op.set_output_tensor(fm_scaled)
1396 relu_op.set_ifm_ofm_shapes()
1397 DebugDatabase.add_optimised(op, relu_op)
1398
1399 # Add scaled and alpha multiplied values (without scaling)
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001400 add_op = Operation(Op.Add, op.name + "_add")
1401 add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001402 add_op.add_input_tensor(fm_alpha)
1403 add_op.add_input_tensor(fm_scaled)
1404 add_op.set_output_tensor(ofm)
1405 add_op.set_ifm_ofm_shapes()
1406
1407 DebugDatabase.add_optimised(op, add_op)
1408 ifm.consumer_list.remove(op)
1409 op = add_op
1410
1411 return op
1412
1413
Raul Farkas66207142023-05-25 11:15:20 +01001414def convert_mul_max_to_abs_or_lrelu(op: Operation, arch, nng) -> Operation:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001415 r"""Whenever there is a subgraph with this topology:
1416
Jonas Ohlssond8575072022-03-30 10:30:25 +02001417 Input X For X = -1 or X > 0
1418 | \ / This subgraph can be replaced with either
1419 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
1420 | /
1421 Max
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001422 """
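# For example, with the constant equal to -1 the pattern Max(x, Mul(x, -1)) = max(x, -x)
# computes Abs(x)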
1423
1424 if op.type == Op.Maximum:
1425 # finds the Mul input(s) to the Max
1426 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
1427 if len(muls) == 1:
1428 mul = muls[0].ops[0]
1429 elif len(muls) == 2:
1430 # In the case both inputs are Muls, find the one with the same input as the Max
Fredrik Svedberg66591652022-08-29 10:51:27 +02001431 mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
1432 if len(mul_ifms):
1433 mul = mul_ifms[0].ops[0]
1434 else:
1435 # Not using same input
1436 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001437 else:
1438 # No Mul inputs
1439 return op
1440
1441 # make sure the Mul doesn't have any other consumers
1442 mul_ofm = mul.outputs[0]
1443 if len(mul_ofm.consumers()) != 1:
1444 return op
1445 # make sure the Mul doesn't have a fused activation function
1446 if mul.activation:
1447 return op
1448 ifm, ofm = op.get_ifm_ofm()
1449 if ifm is None or ofm is None:
1450 return op
1451
1452 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1453 return op
1454 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
1455 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
1456 return op
1457
1458 # finds the branched input that goes to both the Max and the Mul
1459 shared = set(op.inputs) & set(mul.inputs)
1460 if len(shared) == 1:
1461 shared_in = shared.pop()
1462 # find the constant scalar input to the Mul
1463 const_tens = (set(mul.inputs) - {shared_in}).pop()
1464 # check that it is a scalar
1465 if const_tens.shape != []:
1466 return op
1467 const = const_tens.ops[0]
1468 # check that it is a constant
1469 if const.type != Op.Const:
1470 return op
1471 # Remove the Mul from the shared input's consumers
1472 shared_in.consumer_list.remove(mul)
1473 else:
1474 return op
1475
1476 val = const.outputs[0].values
1477 if val >= 0:
1478 new_op = Op.LeakyRelu
1479 op.attrs["alpha"] = val
1480 # to produce bit exact results, the alpha is not enough;
1481 # save additional scaling info in attr "alpha_scaling", to be used as input
1482 # to the LUT construction
James Peet7519d502021-07-19 16:47:58 +01001483 alpha_scalar = const_tens.values - const_tens.quantization.zero_point
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001484 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
1485 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
1486 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
1487 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
1488 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
1489 elif val == -1:
1490 new_op = Op.Abs
1491 else:
1492 return op
1493
1494 op.type = new_op
1495 op.name = op.name.replace("Maximum", new_op.name)
1496 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
1497 op.inputs = [shared_in]
1498 op.set_ifm_ofm_shapes()
1499
1500 # Record optimisation in debug database
1501 DebugDatabase.add_optimised(op, op)
1502
1503 return op
1504
1505
Raul Farkas66207142023-05-25 11:15:20 +01001506def convert_hardswish_to_lut(op: Operation, arch, nng) -> Operation:
1507 """Convert HardSwish to LUT to allow for support on NPU."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001508 if op.type == Op.HardSwish:
1509 ifm, ofm = op.get_ifm_ofm()
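# The LUT approximates the reference HardSwish, hswish(x) = x * relu6(x + 3) / 6, evaluated
# in fixed point for every possible quantised input value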
1510 # Generate the LUT
1511 ifm_scale = np.double(ifm.quantization.scale_f32)
1512 ofm_scale = np.double(ofm.quantization.scale_f32)
1513 zp_in = ifm.quantization.zero_point
1514 zp_out = ofm.quantization.zero_point
1515 ifm_scale_hires = (1 / 128) * ifm_scale
1516 relu_multiplier = np.double(3 / 32768)
1517 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
1518 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
1519 # Use 16bit scale
1520 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
1521 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
1522
1523 values = []
1524 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1525 quantized_min = min(ix)
1526 quantized_max = max(ix)
1527 for x in ix:
1528 input_value = x - zp_in
1529 input_value_hires = input_value * 128
1530 # Compute the input value on essentially the output scale, not shifted yet
1531 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
1532 # Compute the "relu-ish multiplier". This matches the code in the TensorFlow Lite Micro kernel
1533 relu_value = np.int16(input_value_hires)
1534 if relu_shift < 31:
1535 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
1536
1537 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
1538
1539 if relu_shift < 31:
1540 relu_value = fp_math.shift_left16(relu_value, 1)
1541
1542 if relu_shift > 31:
1543 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
1544
1545 # The value is now a 16-bit fixed-point relu_value in [-1, 1]
1546 # Convert it to a 16-bit fixed-point value in [0, 1]
1547 relu_value = (relu_value + (1 << 15)) >> 1
1548 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
1549 shift = 31 - out_shift
1550 shift = -shift if shift < 0 else 0
1551 # Finally apply the output shift
1552 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
1553 lut_result = min(quantized_max, max(quantized_min, lut_result))
1554 values.append(lut_result)
1555 return convert_to_lut(op, values, "hardswish")
1556 return op
1557
1558
1559def convert_lrelu_to_mul_max(op, arch):
1560 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
1561 # (the opposite of convert_mul_max_to_abs_or_lrelu)
1562 ifm, ofm = op.get_ifm_ofm()
1563 if ifm is None or ofm is None:
1564 return op
1565
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001566 alpha = np.float32(op.attrs["alpha"])
1567 use_mul_max = 0 < alpha < 1
Fredrik Svedberg36424312022-09-16 09:39:26 +02001568 is_converted_prelu = "alpha_scaling" in op.attrs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001569 if use_mul_max:
1570 mul_ifm = ifm
1571 new_op = Op.Maximum
1572 else:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001573 # Need to use a different approach for alpha < 0 or alpha > 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001574 no_scale_quant = ifm.quantization.clone()
1575 no_scale_quant.scale_f32 = None
1576 no_scale_quant.zero_point = 0
1577 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
1578
1579 # Select values < 0
1580 min_op = Operation(Op.Minimum, op.name + "_min")
1581 min_op.add_input_tensor(ifm)
1582 min_op.add_input_tensor(zero)
1583 mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001584 if alpha < 0 and not is_converted_prelu:
1585 # For negative alpha that is not from a converted PReLU we need to use
1586 # int32 Mul below to perform the (negative) alpha scaling
1587 mul_ifm.dtype = DataType.int32
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001588 min_op.set_output_tensor(mul_ifm)
1589 min_op.set_ifm_ofm_shapes()
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001590 new_op = Op.Add
1591 op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001592 DebugDatabase.add_optimised(op, min_op)
1593
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001594 # Add multiplication with alpha
1595 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001596 mul_alpha.add_input_tensor(mul_ifm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001597 # Create const tensor containing alpha as scalar
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001598 quantization = ifm.quantization.clone()
1599 quantization.min = 0
1600 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
1601 quantization.zero_point = 0
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001602 alpha_dtype = mul_ifm.dtype
Fredrik Svedberg36424312022-09-16 09:39:26 +02001603 if is_converted_prelu:
1604 # The LeakyRelu was the result from convert_prelu and the scaling is provided
Fredrik Svedberg66591652022-08-29 10:51:27 +02001605 scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001606 mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001607 elif alpha == 0 or np.isinf(1 / alpha):
1608 # Handling of alpha near or at zero
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001609 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001610 scalar = 0
1611 else:
1612 quantization.scale_f32 = alpha
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001613 if alpha_dtype == DataType.int32:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001614 # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001615 scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
1616 else:
1617 scalar = 1
Tim Hall3b1578e2023-01-13 17:57:25 +00001618 alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001619 mul_alpha.add_input_tensor(alpha_tens)
1620 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1621 mul_alpha.set_output_tensor(fm_alpha)
1622 mul_alpha.set_ifm_ofm_shapes()
1623 DebugDatabase.add_optimised(op, mul_alpha)
1624
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001625 if not use_mul_max:
1626 relu_op = Operation(Op.Relu, op.name + "_relu")
1627 relu_op.add_input_tensor(ifm)
1628 fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1629 relu_op.set_output_tensor(fm_id)
1630 relu_op.set_ifm_ofm_shapes()
1631 DebugDatabase.add_optimised(op, relu_op)
1632 elif check_quantized_tens_scaling_equal(ifm, ofm):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001633 # No identity multiplication is needed
1634 fm_id = ifm
1635 else:
1636 # Add multiplication with identity
1637 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1638 mul_identity.add_input_tensor(ifm)
1639 # Create const tensor containing identity as scalar
1640 quantization = ifm.quantization.clone()
1641 quantization.min = 0
1642 quantization.max = quantization.quant_max - quantization.quant_min
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001643 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001644 quantization.zero_point = 0
Tim Hall3b1578e2023-01-13 17:57:25 +00001645 identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001646 mul_identity.add_input_tensor(identity_tens)
1647 # Make sure that fm_id is allocated to a different address than fm_alpha
1648 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1649 mul_identity.set_output_tensor(fm_id)
1650 mul_identity.set_ifm_ofm_shapes()
1651 DebugDatabase.add_optimised(op, mul_identity)
1652
1653 # Convert LeakyRelu to Max (or Add when use_mul_max is False), add the results of the multiplication(s) as inputs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001654 op.type = new_op
1655 op.name = op.name.replace("LeakyRelu", new_op.name)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001656 op.inputs = []
1657 ifm.consumer_list.remove(op)
1658 op.add_input_tensor(fm_alpha)
1659 op.add_input_tensor(fm_id)
1660 op.set_ifm_ofm_shapes()
1661
1662 DebugDatabase.add_optimised(op, op)
1663 return op
1664
1665
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001666def convert_to_lut8(op, fn, fn_name):
1667 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1668 # fn is a function(real) -> real
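# Illustrative entry (hypothetical quantisation): with ifm_scale=0.1, zp_in=0, ofm_scale=1/256
# and zp_out=-128, the table value for x=10 is round_away_zero(-128 + fn(1.0) * 256),
# i.e. about 59 when fn is the sigmoid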
1669 ifm, ofm = op.get_ifm_ofm()
1670 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1671 return op
1672 # Generate the LUT
1673 ifm_scale = np.double(ifm.quantization.scale_f32)
1674 ofm_scale = np.double(ofm.quantization.scale_f32)
1675 zp_in = ifm.quantization.zero_point
1676 zp_out = ofm.quantization.zero_point
1677 values = []
1678 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1679 quantized_min = min(ix)
1680 quantized_max = max(ix)
1681 for x in ix:
1682 x_real = ifm_scale * (x - zp_in)
1683 y_real = fn(x_real)
1684 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1685 lut_result = min(quantized_max, max(quantized_min, lut_result))
1686 values.append(lut_result)
1687 return convert_to_lut(op, values, fn_name)
1688
1689
1690def convert_lrelu_to_lut(op, arch):
1691 ifm, ofm = op.get_ifm_ofm()
1692 # Generate the LUT
1693 alpha = op.attrs["alpha"]
1694 ifm_scale = np.double(ifm.quantization.scale_f32)
1695 ofm_scale = np.double(ofm.quantization.scale_f32)
1696 zp_in = ifm.quantization.zero_point
1697 zp_out = ofm.quantization.zero_point
1698 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1699 alpha_scalar = 1
1700 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1701 if "alpha_scaling" in op.attrs:
1702 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1703 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1704 values = []
1705 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1706 quantized_min = min(ix)
1707 quantized_max = max(ix)
1708 for x in ix:
1709 if x < zp_in:
1710 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1711 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1712 )
1713 else:
1714 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1715 lut_result = min(quantized_max, max(quantized_min, lut_result))
1716 values.append(lut_result)
1717 return convert_to_lut(op, values, "lrelu")
1718
1719
Raul Farkas66207142023-05-25 11:15:20 +01001720def convert_lrelu(op: Operation, arch, nng) -> Operation:
1721 """Convert LeakyRelu to a LUT based solution if possible, otherwise a mul + max."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001722 if op.type != Op.LeakyRelu:
1723 return op
1724 ifm, ofm = op.get_ifm_ofm()
1725 if ifm is None or ofm is None:
1726 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001727 alpha = op.attrs["alpha"]
1728 if alpha == 0:
1729 # When alpha is 0 the operation can be converted to a ReLU
1730 op.type = Op.Relu
1731 op.name = op.name.replace("LeakyRelu", op.type.name)
1732 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001733 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1734 # use LUT for int8/uint8
1735 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001736 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001737 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001738 return op
1739 return convert_lrelu_to_mul_max(op, arch)
1740
1741
Raul Farkas66207142023-05-25 11:15:20 +01001742def convert_tanh_sigmoid_to_lut(op: Operation, arch, nng) -> Operation:
1743 """Convert int8/uint8 Sigmoid and Tanh to a LUT based solution."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001744 if op.type == Op.Sigmoid:
1745 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1746 elif op.type == Op.Tanh:
1747 return convert_to_lut8(op, math.tanh, "tanh")
1748 return op
1749
1750
Johan Gunnarsson98556372023-08-10 13:10:44 +02001751def convert_quantize(op: Operation, arch, nng) -> Operation:
1752 """Convert Quantize to Avgpool. This conversion only works for int-to-int re-quantization and
1753 not to/from floats. Therefore, this rewrite should only run after the supported ops check to
1754 avoid rewriting ops that will run on CPU."""
1755 if op.type == Op.Quantize:
1756 # Create a new AvgPool op and steal its attrs, then reuse the original op with different type
1757 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
1758 op.type = Op.AvgPool
1759 op.attrs = avgpool_op.attrs.copy()
1760
1761 DebugDatabase.add_optimised(op, op)
1762
1763 return op
1764
1765
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001766def fuse_activation_function_with_prev(op, arch, nng):
1767 # if op is a no-op: attempts to move the activation function to the preceding op
1768 if not op.attrs.get("is_nop", False) or op.activation is None:
1769 return op
1770 ifm, ofm = op.get_ifm_ofm()
1771 if ifm is None or ofm is None:
1772 return op
1773 # finds the input(s) to the operation
1774 prev_op = ifm.ops[0]
1775 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1776 fuse = (
1777 prev_op.run_on_npu
Johan Alfven67daf2a2023-10-30 20:39:01 +01001778 and prev_op.type != Op.Memcpy
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001779 and prev_op.type.npu_block_type != NpuBlockType.Default
1780 and len(ifm.ops) == 1
1781 and len(prev_op.outputs[0].consumers()) == 1
1782 and prev_op.activation is None
1783 )
1784 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1785 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1786 # LUT currently only works correctly for elementwise ops
1787 fuse = False
1788 if not fuse:
1789 return op
1790 # Move the fused activation function + corresponding info to prev_op
1791 prev_op.activation = op.activation
1792 prev_op.forced_output_quantization = op.forced_output_quantization
1793 if op.activation_lut is not None:
1794 prev_op.set_activation_lut(op.activation_lut)
1795 # Bypass op
1796 prev_op.set_output_tensor(ofm)
wilisa0179a89042022-11-02 17:18:43 +00001797 DebugDatabase.add_optimised(prev_op, prev_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001798 return op
1799
1800
1801def _leading_pad_ok(leading_pad, stride, kernel_size):
1802 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1803 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
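# For example (illustrative values): kernel_size 7 and stride 2 give max_size 3, so a leading
# pad of 2 is accepted (2 % 2 == 0) while a leading pad of 1 is not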
1804 max_size = kernel_size // 2
1805 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1806
1807
Raul Farkas66207142023-05-25 11:15:20 +01001808def replace_pad_by_hw_pad(op: Operation, arch, nng) -> Operation:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001809 """
1810 Tries to completely remove a PAD operator by using hardware padding.
1811 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1812 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1813 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1814 if both operations can be run on the NPU.
1815 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1816 """
1817 if (
1818 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001819 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001820 and op.run_on_npu
1821 and op.attrs["padding"] == Padding.VALID
1822 ):
1823 pad_op = op.ifm.ops[0]
1824 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1825 return op
1826 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1827 return op
1828 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1829 k = op.kernel
1830 k_w, k_h = k.dilated_wh()
1831
1832 # Check if the PAD operator can be replaced by hardware padding
1833 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1834 # Too much padding, it would require hardware padding to actually insert zeros
1835 return op
1836 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1837 return op
1838
1839 if op.type.is_avgpool_op():
1840 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1841 for pad, k_size in (
1842 (left, k_w),
1843 (right, k_w),
1844 (top, k_h),
1845 (bottom, k_h),
1846 ):
1847 if pad not in (0, k_size // 2):
1848 return op
1849 # Average pool is converted to depthwise, because NPU average pool + same padding
1850 # has a special implementation that is different from PAD followed by average pool with
1851 # valid padding.
1852 k_w, k_h = op.kernel.width, op.kernel.height
1853 ifm = op.ifm
1854 # Remember other inputs
1855 other_inputs = op.inputs[1:]
1856 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1857 quantization = QuantizationParameters(0.0, 255.0)
1858 quantization.scale_f32 = 1.0 / (k_w * k_h)
1859 quantization.zero_point = 0
1860 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1861 weights = np.full(shape, 1)
1862
1863 weight_tens = create_const_tensor(
1864 op.name + "_weights",
1865 shape,
1866 op.ifm.dtype,
1867 weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001868 purpose=TensorPurpose.Weights,
1869 quantization=quantization,
1870 )
James Peet7519d502021-07-19 16:47:58 +01001871 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001872 op.type = Op.DepthwiseConv2DBias
1873 op.inputs = []
1874 op.add_input_tensor(ifm)
1875 op.add_input_tensor(weight_tens)
Tim Hall5ff4cd12023-05-16 22:39:14 +01001876
1877 if op.ifm.dtype == DataType.uint8:
1878 op.rounding_mode = RoundingMode.HalfUp
1879
1880 # Add bias tensor, all biases set to 0
1881 op.inputs.append(None)
1882 fixup_bias_tensors(op, arch, nng, DataType.int32)
1883
1884 else:
1885 op.rounding_mode = RoundingMode.AwayZero
1886
1887 # The DepthwiseConv needs to be performed with the IFM zero point set appropriately so that the correct
1888 # pad values are used. However, in order to use the rounding away from zero mode the zero point needs to
1889 # have been removed so that the zero point is at zero. This is done by adding a kernel sized amount of
1890 # the zero point as a bias. The datatype of the bias needs to be set to int32, even for an int16 IFM,
1891 # because this will cause full precision scaling to be used (see weight compression). Finally, the OFM
1892 # zero point will need forcing to zero (as it has already been removed)
1893 nr_biases = op.inputs[1].shape[-1]
1894 bias_values = [op.ifm.quantization.zero_point * k_h * k_w] * nr_biases
1895 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
1896 op.add_input_tensor(bias_tensor)
1897
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001898 # Add other inputs
1899 op.inputs.extend(other_inputs)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001900
1901 # Bypass the PAD operator
1902 op.set_input_tensor(pad_op.ifm, 0)
1903 # Adjust the padding attributes of the convolution operator
1904 op.attrs["padding"] = Padding.EXPLICIT
1905 op.attrs["explicit_padding"] = (top, left, bottom, right)
1906 op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +00001907 DebugDatabase.add_optimised(op, op)
1908
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001909 return op
1910
1911
Rickard Bolinfdbb0722023-09-05 11:38:19 +00001912def convert_mirror_pad(op: Operation, arch, nng):
1913 if op.type != Op.MirrorPad or not op.run_on_npu:
1914 return op
1915
1916 _, (top, bot), (left, right), _ = op.ifm2.values
1917 mode = op.attrs["mode"] # 0 = reflect, 1 = symmetric
1918
1919 ifm = op.ifm
1920 ofm = op.ofm
1921 ofm.ops = []
1922 elem_size = 2 if ofm.dtype == DataType.int16 else 1
1923 n, h, w, c = ifm.shape
1924 _, oh, ow, _ = ofm.shape
1925 # Force linear format on OFM to allow negative stride multipliers
1926 ofm.force_linear_format = True
1927
1928 # Intermediate ofm needed to store ifm padded with top and bot values as input to the left and right padding
1929 intermediate_ofm_tens = Tensor([n, h + top + bot, w, c], ofm.dtype, "intermediate_ofm_tens")
1930 intermediate_ofm_tens.quantization = op.outputs[0].quantization.clone()
1931 intermediate_ofm_tens.force_linear_format = True
1932
1933 # If there is no left or right padding, we can write directly to the ofm without an intermediate tensor
1934 if not (left or right):
1935 intermediate_ofm_tens = ofm
1936
1937 # Initial op to copy the ifm into the middle of the intermediate ofm
1938 avg_pool_init = create_avgpool_nop("init_pool")
1939 avg_pool_init.write_shape = Shape4D(n, h, w, c)
1940 avg_pool_init.write_offset = Shape4D(0, top, 0, 0)
1941 avg_pool_init.read_shapes[0] = Shape4D(n, h, w, c)
1942 avg_pool_init.read_offsets[0] = Shape4D(0, 0, 0, 0)
1943 avg_pool_init.add_input_tensor(ifm)
1944 avg_pool_init.set_output_tensor(intermediate_ofm_tens)
1945 avg_pool_init.set_ifm_ofm_shapes()
1946 DebugDatabase.add_optimised(op, avg_pool_init)
1947
1948 # Create pools with negative stride to mirror edges and offset to write at padding positions
1949 avg_pool_pad = create_avgpool_nop("pad_pool")
1950 for i, pad_amount in enumerate([top, bot, left, right]):
1951 # Clear input from previous cloned op
1952 avg_pool_pad.inputs = []
1953 if not pad_amount:
1954 continue
1955
1956 if i == 0: # top
1957 # Set read and write shape width to full ifm width and height to "top" pad size
1958 avg_pool_pad.write_shape = Shape4D(n, top, w, c)
1959 avg_pool_pad.read_shapes[0] = Shape4D(n, top, w, c)
1960 # Leave read offset as default to read the top chunk of the ifm
1961 # For reflect mode, shift height offset down one step to "skip" the edge
1962 avg_pool_pad.read_offsets[0] = Shape4D(0, 0, 0, 0) if mode == 1 else Shape4D(0, 1, 0, 0)
1963 # Offset the base address of tile 0 to start writing just above the ifm that was copied into the middle of
1964 # the ofm and use negative height striding to mirror the above ifm chunk
1965 avg_pool_pad.tile_base_offsets_ofm[0] = ((top - 1) * w) * c * elem_size
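# Illustrative example: with top=2, w=4, c=1 and int8 data this offset is 4 bytes, i.e. the
# start of row 1 of the intermediate OFM; the negative height stride then fills pad rows 1
# and 0 with mirrored copies of the first IFM rows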
1966 if i == 1: # bot
1967 # Set read and write shape width to full ifm width and height to "bot" pad size
1968 avg_pool_pad.write_shape = Shape4D(n, bot, w, c)
1969 avg_pool_pad.read_shapes[0] = Shape4D(n, bot, w, c)
1970 # Set read offset to read the bottom chunk of the ifm
1971 # For reflect mode, shift height offset up one step to "skip" the edge
1972 avg_pool_pad.read_offsets[0] = Shape4D(0, h - bot, 0, 0) if mode == 1 else Shape4D(0, h - bot - 1, 0, 0)
1973 # Offset the base address of tile 0 to start writing at the very bottom of the ofm and use negative height
1974 # striding to mirror the above ifm chunk
1975 avg_pool_pad.tile_base_offsets_ofm[0] = (oh - 1) * w * c * elem_size
1976 if i == 2: # left
1977 # Set read and write shape height to full intermediate ofm height and width to "left" pad size
1978 avg_pool_pad.write_shape = Shape4D(n, h + top + bot, left, c)
1979 avg_pool_pad.read_shapes[0] = Shape4D(n, h + top + bot, left, c)
1980 # Leave read offset as default to read the leftmost chunk of the intermediate ofm
1981 # For reflect mode, shift width offset one step to the right to "skip" the edge
1982 avg_pool_pad.read_offsets[0] = Shape4D(0, 0, 0, 0) if mode == 1 else Shape4D(0, 0, 1, 0)
1983 # Offset the base address of tile 0 to start writing just left of the intermediate ofm and use negative
1984 # width striding to mirror the above ifm chunk
1985 avg_pool_pad.tile_base_offsets_ofm[0] = (left - 1) * c * elem_size
1986 if i == 3: # right
1987 # Set read and write shape height to full intermediate ofm height and width to "right" pad size
1988 avg_pool_pad.write_shape = Shape4D(n, h + top + bot, right, c)
1989 avg_pool_pad.read_shapes[0] = Shape4D(n, h + top + bot, right, c)
1990 # Set read offset to read the rightmost chunk of the intermediate ofm
1991 # For reflect mode, shift width offset one step to the left to "skip" the edge
1992 avg_pool_pad.read_offsets[0] = Shape4D(0, 0, w - right, 0) if mode == 1 else Shape4D(0, 0, w - right - 1, 0)
1993 # Offset the base address of tile 0 to start writing at the rightmost part of the ofm and use negative
1994 # width striding to mirror the above ifm chunk
1995 avg_pool_pad.tile_base_offsets_ofm[0] = (ow - 1) * c * elem_size
1996
1997 # Write offset (0,0,0,0) for all convs
1998 avg_pool_pad.write_offset = Shape4D(0, 0, 0, 0)
1999
2000 if i in [0, 1]: # negative height stride for top and bot, negative width stride for left and right
2001 avg_pool_pad.ofm_stride_multiplier = [1, -1, 1] # C/H/W
2002 # top and bot reads from ifm and writes to intermediate ofm
2003 avg_pool_pad.add_input_tensor(ifm)
2004 intermediate_ofm_tens.ops.append(avg_pool_pad)
2005 avg_pool_pad.outputs = [intermediate_ofm_tens]
2006 else:
2007 avg_pool_pad.ofm_stride_multiplier = [1, 1, -1] # C/H/W
2008 # left and right reads from intermediate ofm and writes to ofm
2009 avg_pool_pad.add_input_tensor(intermediate_ofm_tens)
2010 ofm.ops.append(avg_pool_pad)
2011 avg_pool_pad.outputs = [ofm]
2012
2013 avg_pool_pad.set_ifm_ofm_shapes()
2014 DebugDatabase.add_optimised(op, avg_pool_pad)
2015
2016 # Clone operation for next padding direction
2017 avg_pool_pad = avg_pool_pad.clone(f"_{i}")
2018
2019 if left or right:
2020 # Copy intermediate ofm into final ofm
2021 avg_pool_final_copy = create_avgpool_nop("avg_pool_final_copy")
2022 avg_pool_final_copy.write_shape = Shape4D(n, h + top + bot, w, c)
2023 avg_pool_final_copy.write_offset = Shape4D(0, 0, left, 0)
2024 avg_pool_final_copy.read_shapes[0] = Shape4D(n, h + top + bot, w, c)
2025 avg_pool_final_copy.read_offsets[0] = Shape4D(0, 0, 0, 0)
2026
2027 avg_pool_final_copy.add_input_tensor(intermediate_ofm_tens)
2028 ofm.ops.append(avg_pool_final_copy)
2029 avg_pool_final_copy.outputs = [ofm]
2030 avg_pool_final_copy.set_ifm_ofm_shapes()
2031 DebugDatabase.add_optimised(op, avg_pool_final_copy)
2032
2033 return op
2034
2035
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002036def convert_pad(op: Operation, arch, nng):
2037 """
2038 Rewrites PAD operator to an average pool that copies the IFM to the OFM
2039 + up to 4 average pool operators that fill the OFM with zeros at the borders.
2040 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
2041 """
2042 if op.type != Op.Pad or not op.run_on_npu:
2043 return op
2044 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
2045
2046 ifm = op.ifm
2047 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01002048 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002049 ofm = op.ofm
2050 assert ofm is not None
2051 ofm.ops = []
2052 ofm_shape = op.ofm_shapes[0]
2053
2054 # Average pool op that copies IFM to the right place inside the OFM
2055 shp0 = Shape4D(0, 0, 0, 0)
2056 shp_top = shp0.with_height(top)
2057 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
2058 avgpool_op.activation = op.activation
2059 quant = ofm.quantization
2060 pad_value = quant.zero_point
2061 # Add operations that fill the borders of the OFM
2062 if top > 0:
2063 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
2064 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00002065 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002066 )
2067 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
2068 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
2069 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
2070 if bottom > 0:
2071 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
2072 zero_tens = create_const_tensor(
2073 op.name + "_bottom",
2074 shape.as_list(),
2075 ofm.dtype,
2076 shape.elements() * [pad_value],
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002077 quantization=quant,
2078 )
2079 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
2080 create_avg_pool_for_concat(
2081 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
2082 )
2083 if left > 0:
2084 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
2085 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00002086 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002087 )
2088 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
2089 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
2090 if right > 0:
2091 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
2092 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00002093 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002094 )
2095 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
2096 create_avg_pool_for_concat(
2097 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
2098 )
2099
2100 op.type = Op.ConcatTFLite
2101 return avgpool_op
2102
2103
Raul Farkas66207142023-05-25 11:15:20 +01002104def fixup_bias_tensors(op: Operation, arch, nng, dtype=None) -> Operation:
2105 """Fixup ops that require a bias and don't have one by adding a bias tensor filled with zeros."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002106 if op.type.needs_bias() and op.bias is None:
2107 # Op has no bias, add bias tensor filled with zeros
2108 nr_biases = op.inputs[1].shape[-1]
2109 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02002110 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
2111 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
2112 # For int16 the selected bias DataType will have an impact on the scaling
2113 # used when encoding the scales and biases later. The default mode will match the
2114 # reference with reduced scaling for int64 bias.
2115 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
2116 # is used to emulate average pool int32 bias should be selected for full precision
2117 # int16 scaling.
2118 if dtype is None:
2119 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
2120 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Raul Farkas3e7157b2023-05-09 09:09:17 +01002121 bias_index = op.type.info.indices.biases[0]
2122 if bias_index < len(op.inputs):
2123 op.set_input_tensor(bias_tensor, bias_index)
2124 else:
2125 op.add_input_tensor(bias_tensor)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002126
2127 return op
2128
2129
wilisa0146c94772023-02-08 09:56:14 +00002130def detect_asymmetric_weights(op):
2131 # Check all ops (cpu and npu)
2132 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
2133 if op.ifm.dtype in (DataType.int8, DataType.int16):
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01002134 if not np.all(op.weights.quantization.zero_point == 0):
wilisa0146c94772023-02-08 09:56:14 +00002135 print(f"Warning: Op {op.type} '{op.name}' has asymmetric weights.", end=" ")
2136 return True
2137 return False
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01002138
wilisa0146c94772023-02-08 09:56:14 +00002139
Raul Farkas66207142023-05-25 11:15:20 +01002140def fixup_asymmetric_weights(op: Operation, arch, nng) -> Operation:
wilisa0146c94772023-02-08 09:56:14 +00002141 if detect_asymmetric_weights(op):
2142 if op.run_on_npu:
2143 print("Zero points have been adjusted.")
2144 op.weights.quantization.zero_point *= 0
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01002145 return op
2146
2147
wilisa0146c94772023-02-08 09:56:14 +00002148def check_asymmetric_weights(op, arch, nng):
2149 # This function can modify the run_on_npu flag which causes an operator to be placed on the CPU. It is usually only
2150 # set by the supported operator checks. Therefore, it should be run immediately after those checks to avoid the
2151 # possibility of other graph optimiser functions modifying the operator (that is later run on the CPU)
2152 if detect_asymmetric_weights(op):
2153 if op.run_on_npu:
2154 print("To run the operator on Ethos-U use the option --force-symmetric-int-weights")
2155 op.run_on_npu = False
2156 return op
2157
2158
2159def fixup_or_check_asymmetric_weights(force_symmetric_int_weights):
2160 if force_symmetric_int_weights:
2161 return fixup_asymmetric_weights
2162 else:
2163 return check_asymmetric_weights
2164
2165
Johan Alfven906c9e82023-05-25 11:18:50 +02002166def convert_squared_difference(op, arch, nng):
2167 if op.type == Op.SquaredDifference and op.run_on_npu:
2168 ifm, ifm2, ofm = op.get_ifm_ifm2_ofm()
2169
2170 identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
2171
2172 # All the calculations/parameters are the same as in the reference kernel
2173 twice_max_input_scale = np.double(2.0 * max(ifm.quantization.scale_f32, ifm2.quantization.scale_f32))
2174 real_input1_multiplier = np.double(ifm.quantization.scale_f32) / twice_max_input_scale
2175 real_input2_multiplier = np.double(ifm2.quantization.scale_f32) / twice_max_input_scale
2176
2177 left_shift = 0 if op.ifm.dtype == DataType.int16 else 7
2178
2179 real_output_multiplier = (twice_max_input_scale * twice_max_input_scale) / (
2180 np.double((1 << (left_shift * 2)) * ofm.quantization.scale_f32)
2181 )
2182
2183 input1_multiplier, input1_shift = quantise_scale(real_input1_multiplier)
2184 input2_multiplier, input2_shift = quantise_scale(real_input2_multiplier)
2185 output_multiplier, output_shift = quantise_scale(real_output_multiplier)
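# The graph built below follows the reference kernel: both inputs are cast to int32 with a
# left shift, rescaled with their respective input multipliers, subtracted and squared, and
# the final Mul applies output_multiplier/output_shift to produce the quantised OFM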
2186
2187 input1_multiplier_const = create_const_tensor(
2188 op.name + "_input1_multiplier", [1], DataType.int32, [input1_multiplier], quantization=identity_quant
2189 )
2190 input2_multiplier_const = create_const_tensor(
2191 op.name + "_input2_multiplier", [1], DataType.int32, [input2_multiplier], quantization=identity_quant
2192 )
2193 output_multiplier_const = create_const_tensor(
2194 op.name + "_output_multiplier", [1], DataType.int32, [output_multiplier], quantization=identity_quant
2195 )
2196
2197 # Convert ifm to 32 bit
2198 ifm_32bit_shifted = ifm.clone(suffix="_ifm_32bit_shifted", set_unique=True)
2199 ifm_32bit_shifted.dtype = DataType.int32
2200 ifm_32bit_shifted.quantization = identity_quant
2201 cast_op = create_cast_op(op.name + "_ifm_32bit_shifted", ifm, ifm_32bit_shifted)
2202 # Use explicit scaling (multiplier) for the left shift
2203 cast_op.explicit_scaling = ExplicitScaling(False, [0], [1 << left_shift])
2204 DebugDatabase.add_optimised(op, cast_op)
2205
2206 # The 32 bit Mul op does not scale the value, so the input has to be multiplied with the "multiplier" calculated above
2207 ifm_scaled = ifm.clone(suffix="_scaled", set_unique=True)
2208 ifm_scaled.dtype = DataType.int32
2209 ifm_scaled.quantization = identity_quant
2210 mul_op = Operation(Op.Mul, op.name + "_scaled_input1")
2211 mul_op.add_input_tensor(ifm_32bit_shifted)
2212 mul_op.add_input_tensor(input1_multiplier_const)
2213 mul_op.set_output_tensor(ifm_scaled)
2214 # Use explicit scaling for the shift (multiplier not actually used for int32, but value can not be empty)
2215 mul_op.explicit_scaling = ExplicitScaling(False, [input1_shift], [input1_multiplier])
2216 mul_op.set_ifm_ofm_shapes()
2217 DebugDatabase.add_optimised(op, mul_op)
2218
2219 # Convert ifm2 to 32 bit
2220 ifm2_32bit_shifted = ifm2.clone(suffix="_ifm2_32bit_shifted", set_unique=True)
2221 ifm2_32bit_shifted.dtype = DataType.int32
2222 ifm2_32bit_shifted.quantization = identity_quant
2223 cast_op = create_cast_op(op.name + "_ifm2_32bit_shifted", ifm2, ifm2_32bit_shifted)
2224 # Use explicit scaling (multiplier) for the left shift
2225 cast_op.explicit_scaling = ExplicitScaling(False, [0], [1 << left_shift])
2226 DebugDatabase.add_optimised(op, cast_op)
2227
2228 # The 32 bit Mul op does not scale the value, so the input has to be multiplied with the "multiplier" calculated above
2229 ifm2_scaled = ifm2.clone(suffix="_scaled", set_unique=True)
2230 ifm2_scaled.dtype = DataType.int32
2231 ifm2_scaled.quantization = identity_quant
2232 mul_op = Operation(Op.Mul, op.name + "_scaled_input2")
2233 mul_op.add_input_tensor(ifm2_32bit_shifted)
2234 mul_op.add_input_tensor(input2_multiplier_const)
2235 mul_op.set_output_tensor(ifm2_scaled)
2236 # Use explicit scaling for the shift (multiplier not actually used for int32, but value can not be empty)
2237 mul_op.explicit_scaling = ExplicitScaling(False, [input2_shift], [input2_multiplier])
2238 mul_op.set_ifm_ofm_shapes()
2239 DebugDatabase.add_optimised(op, mul_op)
2240
2241 # Calculate the raw diff
2242 raw_diff = ifm.clone(suffix="_raw_diff", set_unique=True)
2243 raw_diff.dtype = DataType.int32
2244 raw_diff.quantization = None
2245 sub_op = Operation(Op.Sub, op.name + "_raw_diff")
2246 sub_op.add_input_tensor(ifm_scaled)
2247 sub_op.add_input_tensor(ifm2_scaled)
2248 sub_op.set_output_tensor(raw_diff)
2249 sub_op.set_ifm_ofm_shapes()
2250 DebugDatabase.add_optimised(op, sub_op)
2251
2252 # Calculate the squared diff
2253 squared_raw = ifm.clone(suffix="_squared_raw", set_unique=True)
2254 squared_raw.dtype = DataType.int32
2255 squared_raw.quantization = None
2256 mul_op = Operation(Op.Mul, op.name + "_squared_raw")
2257 mul_op.add_input_tensor(raw_diff)
2258 mul_op.add_input_tensor(raw_diff)
2259 mul_op.set_output_tensor(squared_raw)
2260 mul_op.set_ifm_ofm_shapes()
2261 DebugDatabase.add_optimised(op, mul_op)
2262
2263 # The 32 bit Mul op does not scale the value, so the output has to be multiplied with the "multiplier" calculated above
2264 op.set_input_tensor(squared_raw, 0)
2265 op.set_input_tensor(output_multiplier_const, 1)
2266 op.type = Op.Mul
2267 # Use explicit scaling for the shift (multiplier not actually used for int32, but value can not be empty)
2268 op.explicit_scaling = ExplicitScaling(False, [output_shift], [output_multiplier])
2269 op.set_ifm_ofm_shapes()
2270 DebugDatabase.add_optimised(op, op)
2271
2272 return op
2273
2274
Rickard Bolina68b82a2023-04-20 15:12:28 +00002275def convert_mean_to_depthwise_conv(op, arch, nng):
Alexander Hansson90c34b52023-05-31 15:03:03 +00002276 """
2277 When h x w <= 4096 a single DepthwiseConv2DBias is enough (left). When h x w > 4096 there is a need to split into several ops (right).
2278 Do this by splitting up h and changing the read_offset/shape.
2279 Below is an example where ifm is 1x190x64x1
2280 MEAN MEAN
2281 | |-----------------------|----------------------|
2282 DepthwiseConv2DBias 1_DepthwiseConv2DBias 2_DepthwiseConv2DBias 3_DepthwiseConv2DBias
2283 | | | |
2284 MUL |---------ADD-----------| |
2285 | |
2286 |----------------ADD---------------|
2287 |
2288 MUL
2289 1_DepthwiseConv2DBias: read_offset [0, 0, 0, 0]> read_shape [1, 64, 64, 1]>
2290 2_DepthwiseConv2DBias: read_offset [0, 64, 0, 0]> read_shape [1, 64, 64, 1]>
2291 3_DepthwiseConv2DBias: read_offset [0, 128, 0, 0]> read_shape [1, 62, 64, 1]>
2292 """
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002293 if op.type == Op.Mean and op.run_on_npu:
Alexander Hansson90c34b52023-05-31 15:03:03 +00002294 max_kernel_size = 4096
2295 max_height = 64
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002296 inp, axis = op.inputs
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002297 dims = len(inp.shape)
2298 dims_ofm = len(op.ofm.shape)
Alexander Hansson90c34b52023-05-31 15:03:03 +00002299 ofmq = op.ofm.quantization
2300 ifmq = op.ifm.quantization
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002301
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002302 # reduce_axis[i] is true if axis i should be reduced
2303 if axis.shape == []:
2304 reduce_axis = [True if i == axis.values else False for i in range(dims)]
2305 else:
2306 reduce_axis = [True if i in axis.values else False for i in range(dims)]
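# e.g. a 4D input with axis = [1, 2] gives reduce_axis = [False, True, True, False]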
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002307
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002308 ifm_shape = inp.shape.copy()
2309 intermediate_shape = op.ofm.shape.copy()
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01002310
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002311 # Fix intermediate_shape when keep_dims is false
2312 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the intermediate_shape should be 1xHx1xC
2313 if dims_ofm < dims:
2314 for i in range(dims):
2315 if reduce_axis[i]:
2316 intermediate_shape.insert(i, 1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002317
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002318 # Reshape to 4D
Alexander Hanssonda8741a2023-06-30 15:41:13 +00002319 reduce_axis = full_shape(4, reduce_axis, False)
2320 ifm_shape = full_shape(4, ifm_shape, 1)
2321 intermediate_shape = full_shape(4, intermediate_shape, 1)
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002322
2323 # If all dimensions to reduce have shape 1, the operation is essentially a memcpy.
2324 # We can then remove the whole op by propagating ofm to previous ops
2325 if not any([reduce_axis[i] and ifm_shape[i] > 1 for i in range(4)]):
2326 op.type = Op.Memcpy
2327 op = bypass_memory_only_ops(op, arch, nng)
2328 return op
2329
Alexander Hanssonda8741a2023-06-30 15:41:13 +00002330 # Support mean over depth-axis by left-shifting the C channel
2331 # From semantics checks we can assume that one of H,W,C has shape 1
2332 if reduce_axis[3] and ifm_shape[3] > 1:
2333 assert 1 in ifm_shape[1:], "Mean reduction over depth channel, but none of H,W,C has shape 1"
2334 # If W=1 reshape NxHx1xC -> NxHxCx1, else reshape Nx1xWxC -> NxWxCx1
2335 idx_to_del = 2 if ifm_shape[2] == 1 else 1
2336
2337 # Delete axis with size 1
2338 del reduce_axis[idx_to_del]
2339 del ifm_shape[idx_to_del]
2340 del intermediate_shape[idx_to_del]
2341
2342 # Add another element to set channel-axis to one
2343 reduce_axis.append(False)
2344 ifm_shape.append(1)
2345 intermediate_shape.append(1)
2346
2347 # Compute kernel sizes for our convolutions
2348 # Batch axis is implicit as it is only supported if batch size is 1.
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002349 h = ifm_shape[1] if reduce_axis[1] else 1
2350 w = ifm_shape[2] if reduce_axis[2] else 1
2351
Alexander Hansson90c34b52023-05-31 15:03:03 +00002352 num_elements_in_axis = h * w
2353
2354 # If one convolution is enough, but height is greater than max kernel height
2355 # reshape from HxW to 1x(HxW)
2356 # This can only be done if the mean is computed over both H and W
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002357 if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_axis[1] and reduce_axis[2]:
2358 ifm_shape = [ifm_shape[0], 1, h * w, ifm_shape[3]]
Alexander Hansson90c34b52023-05-31 15:03:03 +00002359 w = h * w
2360 h = 1
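# Worked example (hypothetical shapes): h=80, w=40 reduced over both H and W gives
# num_elements_in_axis = 3200 <= 4096, but h=80 > 64, so the IFM is reshaped so that
# h=1 and w=3200, and a single 1x3200 kernel can be used instead of several convolutions.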
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002361
Alexander Hansson90c34b52023-05-31 15:03:03 +00002362 intermediate_op = None
2363 height_per_conv = min(max_kernel_size // w, h)
2364 height_per_conv = min(height_per_conv, max_height)
2365 num_convs = math.ceil(h / height_per_conv)
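# Worked example (hypothetical shapes): h=100, w=64 gives num_elements_in_axis = 6400,
# which exceeds max_kernel_size, so the reshape above does not apply. height_per_conv
# becomes min(4096 // 64, 100) = 64, and num_convs = ceil(100 / 64) = 2: one 64-row
# convolution followed by one 36-row convolution (100 % 64 = 36).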
2366 convs = list()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002367
Alexander Hansson90c34b52023-05-31 15:03:03 +00002368 for i in range(num_convs):
2369 is_last_op = i == (num_convs - 1)
2370
2371 intermediate_op = op.clone(f"{op.name}_conv_{i}")
2372
2373 intermediate_op.type = Op.DepthwiseConv2DBias
2374
2375 # Set necessary depthwise attributes
2376 intermediate_op.attrs.update(
2377 {
2378 "padding": Padding.VALID,
2379 "stride_h": 1,
2380 "stride_w": 1,
2381 "strides": (1, 1, 1, 1),
2382 "depth_multiplier": 1,
2383 "channel_multiplier": 1,
2384 "dilation_h_factor": 1,
2385 "dilation_w_factor": 1,
2386 "dilation": (1, 1, 1, 1),
2387 }
2388 )
2389
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002390 b, _, _, c = ifm_shape
Alexander Hansson90c34b52023-05-31 15:03:03 +00002391
2392 intermediate_tensor = op.ofm.clone(suffix=f"_conv_sum_{i}", set_unique=True)
2393 intermediate_tensor.dtype = DataType.int32
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002394 intermediate_tensor.shape = intermediate_shape
Alexander Hansson90c34b52023-05-31 15:03:03 +00002395 intermediate_op.set_output_tensor(intermediate_tensor)
2396
2397 # as we have several convs, scaling/rounding must be done after the sum has been calculated
2398 intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
2399
2400 # compute height for the kernel
2401 if is_last_op and h % height_per_conv != 0:
2402 weight_h = h % height_per_conv
2403 else:
2404 weight_h = height_per_conv
2405
2406 # compute ifm read offset and shape for the convolution
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002407 read_shape_h = weight_h if reduce_axis[1] else ifm_shape[1]
2408 read_shape_w = w if reduce_axis[2] else ifm_shape[2]
Alexander Hansson90c34b52023-05-31 15:03:03 +00002409
2410 intermediate_op.read_offsets[0] = Shape4D([0, i * height_per_conv, 0, 0])
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002411 intermediate_op.read_shapes[0] = Shape4D(ifm_shape).with_hw(read_shape_h, read_shape_w)
Alexander Hansson90c34b52023-05-31 15:03:03 +00002412
2413 weight_quant = QuantizationParameters(0, 255, scale_f32=1.0, zero_point=0)
2414 weight_shape = [weight_h, w, c, b]
2415 weight_tensor = create_const_tensor(
2416 f"{intermediate_op.name}_weights",
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002417 weight_shape,
Alexander Hansson90c34b52023-05-31 15:03:03 +00002418 DataType.uint8,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002419 np.ones(weight_shape),
Alexander Hansson90c34b52023-05-31 15:03:03 +00002420 TensorPurpose.Weights,
2421 quantization=weight_quant,
2422 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002423
Alexander Hansson90c34b52023-05-31 15:03:03 +00002424 weights_1D = np.ones(np.prod(weight_shape))
2425 weight_tensor.equivalence_id = create_equivalence_id(tuple(weights_1D))
2426 weight_tensor.value_id = weight_tensor.equivalence_id
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002427
Alexander Hansson90c34b52023-05-31 15:03:03 +00002428 intermediate_op.set_input_tensor(weight_tensor, 1)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002429
Alexander Hansson90c34b52023-05-31 15:03:03 +00002430 dtype = DataType.int64 if intermediate_op.ifm.dtype == DataType.int16 else DataType.int32
2431 bias_values = [0] * c
2432 bias = create_const_tensor(f"{intermediate_op.name}_bias", [c], dtype, bias_values)
2433 bias.equivalence_id = create_equivalence_id(tuple(bias_values))
2434 bias.value_id = bias.equivalence_id
2435 intermediate_op.inputs.append(bias)
2436 intermediate_op.set_ifm_ofm_shapes()
Johan Alfven7b3008a2023-04-13 18:54:47 +02002437
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002438 # We want to avoid reshaping the ifm tensor directly, so as not to affect other ops;
Alexander Hansson90c34b52023-05-31 15:03:03 +00002439 # instead we update the shape explicitly for this operation
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002440 intermediate_op.ifm_shapes[0] = Shape4D(ifm_shape)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002441
Alexander Hansson90c34b52023-05-31 15:03:03 +00002442 convs.append(intermediate_op)
2443 DebugDatabase.add_optimised(op, intermediate_op)
2444
2445 # If we have more than one convolution,
2446 # we use add operations to accumulate the intermediate tensors
2447 if len(convs) > 1:
2448 prev_add_op = None
2449 idx = 0
2450
2451 while len(convs):
2452 intermediate_tensor = op.ofm.clone(suffix=f"_add_sum_{idx}", set_unique=True)
2453 intermediate_tensor.dtype = DataType.int32
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002454 intermediate_tensor.shape = intermediate_shape
Alexander Hansson90c34b52023-05-31 15:03:03 +00002455
2456 one_scale_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
2457
2458 ifm = convs.pop().ofm
2459 if not prev_add_op:
2460 ifm2 = convs.pop().ofm
2461 else:
2462 ifm2 = prev_add_op.ofm
Alexander Hansson90c34b52023-05-31 15:03:03 +00002463 intermediate_op = create_add(f"{op.name}_add_{idx}", ifm, ifm2, one_scale_quant)
2464 intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
2465 intermediate_op.set_output_tensor(intermediate_tensor)
2466 intermediate_op.set_ifm_ofm_shapes()
2467
2468 prev_add_op = intermediate_op
2469 idx += 1
2470
2471 DebugDatabase.add_optimised(op, intermediate_op)
2472
2473 # Convert the original Mean op to the final Mul operation,
2474 # which scales and divides by num_elements_in_axis
2475 op.type = Op.Mul
2476 op.name = f"{op.name}_mul"
2477 op.attrs = {}
2478 op.set_input_tensor(intermediate_op.ofm, 0)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002479
Johan Alfven7b3008a2023-04-13 18:54:47 +02002480 # The multiplier is calculated in the same way as in the reference,
2481 # clamping the shift value at the price of some precision loss.
Johan Alfven7b3008a2023-04-13 18:54:47 +02002482 output_multiplier, output_shift_vela = quantise_scale(np.double(ifmq.scale_f32) / np.double(ofmq.scale_f32))
2483
2484 # Convert to reference representation shift value
2485 output_shift = 31 - output_shift_vela
2486
2487 # Reference calculation
2488 # round_down_log2 same as 63 - CountLeadingZeros(num_elements_in_axis)
2489 shift = round_down_log2(num_elements_in_axis)
2490 shift = min(shift, 32)
2491 shift = min(shift, 31 + output_shift)
2492 output_multiplier = (output_multiplier << shift) // num_elements_in_axis
2493 output_shift = output_shift - shift
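# Worked example (assuming the two clamps above do not trigger): with
# num_elements_in_axis = 4096 = 2**12, shift = 12, so (output_multiplier << 12) // 4096
# leaves the multiplier unchanged and output_shift is reduced by 12, i.e. a power-of-two
# element count folds the division entirely into the shift with no precision loss.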
2494
2495 # Convert to vela representation shift
2496 output_shift_vela = 31 - output_shift
2497
2498 # For int32, scaling is not supported, so instead multiply with the scale
2499 # intermediate * scale -> round and shift.
Alexander Hansson90c34b52023-05-31 15:03:03 +00002500 identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
Johan Alfven7b3008a2023-04-13 18:54:47 +02002501 scalar = create_const_tensor(
2502 op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [output_multiplier], quantization=identity_quant
2503 )
Alexander Hansson90c34b52023-05-31 15:03:03 +00002504 op.set_input_tensor(scalar, 1)
2505 op.set_ifm_ofm_shapes()
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002506 op.ofm_shapes[0] = Shape4D(intermediate_shape)
Johan Alfven7b3008a2023-04-13 18:54:47 +02002507
2508 # Reference using TFL rounding for the multiply
Alexander Hansson90c34b52023-05-31 15:03:03 +00002509 op.rounding_mode = RoundingMode.TFLite
Johan Alfven7b3008a2023-04-13 18:54:47 +02002510
2511 # Need to use explicit scaling to get the wanted shift
Alexander Hansson90c34b52023-05-31 15:03:03 +00002512 op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
2513 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002514 return op
2515
2516
Raul Farkas66207142023-05-25 11:15:20 +01002517def convert_ops_to_lut(op: Operation, arch, nng) -> Operation:
2518 """Convert Exp to 8bit or 16bit LUT to allow for support on NPU."""
Johan Alfvence502732023-04-24 13:35:40 +02002519 if op.type == Op.Exp:
2520 if op.ifm.dtype == DataType.int8:
2521 return create_lut_8bit_op(op, math.exp, "exp")
2522 elif op.ifm.dtype == DataType.int16:
2523 return create_lut_int16_op(op, math.exp, "exp")
2524 else:
2525 # Should already be catched in tflite supported ops
2526 assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}"
2527
Johan Alfven8e525ca2023-05-07 13:12:37 +02002528 if op.type == Op.Rsqrt:
2529 return create_lut_rsqrt_int8_op(op)
2530
Johan Alfvence502732023-04-24 13:35:40 +02002531 return op
2532
2533
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002534def optimise_quantize(op: Operation, arch, nng):
2535
2536 if op.type == Op.Quantize and op.run_on_npu:
2537
2538 ifm, ofm = op.get_ifm_ofm()
2539 input_values = ifm.values
2540
2541 # Guard clause - input not const or no values to quantize
2542 if ifm.ops[0].type != Op.Const or input_values is None:
2543 return op
2544
2545 # Singular val in numpy array, convert to indexable array
2546 if input_values.ndim == 0:
2547 input_values = np.array([input_values])
2548
Fredrik Svedberg11563172022-07-06 14:54:12 +02002549 # Requantize int8 to int8 or int16 to int16
2550 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002551
2552 # scale needs to use double precision to match TFLite reference kernel
2553 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
2554 effective_multiplier, effective_shift = quantise_scale(effective_scale)
2555
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002556 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002557 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002558 input_val = val - ifm.quantization.zero_point
2559
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002560 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
2561 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002562
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002563 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
2564 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002565
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002566 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
2567 ofm.values.shape = input_values.shape
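# Worked example (hypothetical scales): ifm scale 0.5 and ofm scale 0.25 give an
# effective_scale of 2.0, so an int8 input value of 10 (both zero points 0) requantizes
# to roughly 20, which is then clamped to the ofm quantization range.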
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002568
2569 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002570 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002571
2572 quantized_vals = []
2573 for val in input_values:
2574
2575 # Derive quantized value
2576 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002577 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
2578 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002579
2580 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002581 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
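# Worked example (hypothetical quantization): ofm scale 0.1 and zero point -128 map a
# float input of 1.0 to 1.0 / 0.1 + (-128) = -118, which is then clipped to
# [quant_min, quant_max].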
2582
2583 # Unsupported data type
2584 else:
2585 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002586
2587 # Make quantize op const and disconnect from parent node
2588
2589 # Remove reference of the current quant op from the parent tensor's consumer list
2590 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
2591
2592 # Clear any references to parent node
2593 op.inputs = []
2594
2595 # Convert this quantize op to const
2596 op.type = Op.Const
2597
2598 return op
2599
2600
Ayaan Masood4965fae2022-06-29 11:30:57 +01002601def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
2602 """Static optimisation for SHAPE operator output value known at compile time"""
2603
2604 # Disconnect the SHAPE operator from its parent and transform it into a constant
2605
2606 if op.type == Op.Shape and op.run_on_npu:
2607
2608 ifm, ofm = op.get_ifm_ofm()
2609
2610 if len(ifm.shape) != ofm.shape[0]:
2611 return op
2612
2613 # Remove reference of the current shape op from the parent tensor's consumer list
2614 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
2615
2616 # Clear any references to parent node
2617 op.inputs = []
2618
2619 # Convert this SHAPE op to const
2620 op.type = Op.Const
2621
2622 # Add size calculation to shape output tensors
2623 ofm.values = np.array(ifm.shape)
2624
2625 return op
2626
2627
Johan Gunnarsson24570f02023-08-29 15:33:10 +02002628def fixup_pool_strides(op: Operation, arch, nng):
Johan Gunnarssonb4e804b2023-09-07 12:43:49 +02002629 """Fixup Pool strides when the kernel size, IFM shape and stride are all equal. The stride can then be changed
2630 to (1, 1) and the padding to VALID, so that the strides are within the limits supported by the NPU."""
Johan Gunnarsson7ccc5832023-09-07 12:28:28 +02002631 if op.type in (Op.AvgPool, Op.MaxPool, Op.QuantizedAvgPool, Op.QuantizedMaxPool):
Johan Gunnarsson24570f02023-08-29 15:33:10 +02002632 ifm, _ = op.get_ifm_ofm()
2633 kernel_w, kernel_h = op.get_kernel_size()
Johan Gunnarssonb4e804b2023-09-07 12:43:49 +02002634 stride_w, stride_h = op.get_kernel_stride()
2635 if kernel_w == stride_w == ifm.shape[2] and kernel_h == stride_h == ifm.shape[1]:
2636 if "strides" in op.attrs:
2637 stride_n, _, _, stride_c = op.attrs["strides"]
2638 op.attrs["strides"] = (stride_n, 1, 1, stride_c)
Johan Gunnarsson24570f02023-08-29 15:33:10 +02002639 op.attrs["stride_w"] = 1
2640 op.attrs["stride_h"] = 1
Johan Gunnarssonb4e804b2023-09-07 12:43:49 +02002641 op.attrs["padding"] = Padding.VALID
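# Worked example (hypothetical shapes): a MaxPool with a 1x8x8x16 IFM, an 8x8 kernel and
# stride 8 satisfies the condition above, so the stride becomes (1, 1) and the padding
# VALID; the single output element is unchanged but the stride is now within NPU limits.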
Johan Gunnarsson24570f02023-08-29 15:33:10 +02002642
2643 return op
2644
2645
Raul Farkas66207142023-05-25 11:15:20 +01002646def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation:
2647 """Fixup Conv2DBias and DepthwiseConv2DBias to allow dilation greater than 2."""
Tim Hallea4ba662022-11-11 18:19:53 +00002648 assert op.run_on_npu
2649 if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
2650 dilation_w, dilation_h = op.get_kernel_dilation()
2651
2652 # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
2653 # kernel
2654 if dilation_w > 2 or dilation_h > 2:
2655 kernel_w, kernel_h = op.get_kernel_size()
2656 kernel_ic = op.weights.shape[-2]
2657 kernel_oc = op.weights.shape[-1]
2658
2659 # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
2660 # of 2. This allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
2661 # odd = 1, even = 2
2662 hw_dilation_h = 1 if (dilation_h & 1) else 2
2663 hw_dilation_w = 1 if (dilation_w & 1) else 2
2664
2665 scale_dilation_h = dilation_h // hw_dilation_h
2666 scale_dilation_w = dilation_w // hw_dilation_w
2667
2668 # create new empty kernel (HWIO format)
2669 new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
2670 new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1
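# Worked example (hypothetical kernel): a 3x3 kernel with dilation (4, 4) gives
# hw_dilation = 2 and scale_dilation = 2, so the new kernel is 5x5 ((3 - 1) * 2 + 1)
# with the original weights at every other position and zeros elsewhere; the remaining
# factor of 2 is provided by the hardware dilation set below.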
2671
2672 new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
2673 new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)
2674
2675 # copy the original kernel values into the new sparse kernel
2676 for h in range(0, kernel_h):
2677 for w in range(0, kernel_w):
2678 new_h = h * scale_dilation_h
2679 new_w = w * scale_dilation_w
2680 new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]
2681
2682 # update the weight tensor with the new dilated kernel
2683 op.weights.shape = new_kernel_shape
2684 op.weights.values = new_kernel_values
2685
2686 # enable(=2) / disable(=1) hardware dilation
2687 op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1) # nhwc format
2688 op.attrs["dilation_h_factor"] = hw_dilation_h
2689 op.attrs["dilation_w_factor"] = hw_dilation_w
2690
2691 return op
2692
2693
Johan Alfvena8fda882023-10-28 16:04:46 +02002694def fixup_transpose(op, arch, nng):
2695 """
2696 Convert Transpose to AvgPool where the strides for height and width is swapped on the OFM
2697 in order to achieve the transpose. It is only possible to swap height and width on the op.
2698
2699 Shape (2,3) transposed to Shape (3,2)
2700 |0|1|2| ifm_stride_w = 1 |0|3| ofm_stride_w = 1
2701 |4|5|6| ifm_stride_h = 3 |1|4| ofm_stride_h = 2
2702 |2|5|
2703
2704 To achieve the above with the AvgPool, the ofm_shape must be set equal to the ifm_shape.
2705 The reason is that AvgPool uses the ofm shape when looping over the memory. So if the
2706 ofm shape is not equal to the ifm shape the full ifm will not be read.
2707 When looping over the values the following formula is used:
2708
2709 IFM [h_pos, w_pos] = h_pos * ifm_stride_h + w_pos * ifm_stride_w
2710 OFM [h_pos, w_pos] = h_pos * ofm_stride_w + w_pos * ofm_stride_h (stride has been swapped)
2711
2712 Below code changes op to an AvgPool and sets the correct shapes. The actual stride swap
2713 is done when creating the ofm featuremap. As seen there are several corner cases
2714 when it is possible to transpose the depth channel.
2715 """
2716 if op.type == Op.Transpose:
2717 op.name = f"{op.name}_avgpool"
2718 op.type = Op.AvgPool
2719 op.attrs["padding"] = Padding.VALID
2720 op.attrs["stride_w"] = 1
2721 op.attrs["stride_h"] = 1
2722 op.attrs["filter_width"] = 1
2723 op.attrs["filter_height"] = 1
2724 op.attrs["strides"] = [1, 1, 1, 1]
2725 op.attrs["ksize"] = [1, 1, 1, 1]
2726 # Swapping strides only works in linear format (ofm)
2727 op.ofm.force_linear_format = True
2728
2729 # Convert IFM to correct 4D shape
2730 perm = op.inputs[1]
2731 ifm_shape = op.ifm.shape
2732
2733 # IFM rank 2 case
2734 if len(ifm_shape) == 2:
2735 # IFM shape: WxC -> 1xWxCx1
2736 op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[1], 1])
2737
2738 # IFM rank 3 cases
2739 elif len(ifm_shape) == 3:
2740 # Check if HxWxC -> WxHxC
2741 if perm.values[0] == 1 and perm.values[1] == 0:
2742 # IFM shape: HxWxC -> 1xHxWxC
2743 op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[1], ifm_shape[2]])
2744
2745 # Check if 1xWxC -> 1xCxW
2746 elif ifm_shape[0] == 1 and perm.values[1] == 2 and perm.values[2] == 1:
2747 # IFM shape: 1xWxC -> 1xWxCx1
2748 op.ifm_shapes[0] = Shape4D([1, ifm_shape[1], ifm_shape[2], 1])
2749
2750 # Check if Hx1xC -> Cx1xH
2751 elif ifm_shape[1] == 1 and perm.values[0] == 2 and perm.values[2] == 0:
2752 # IFM shape: Hx1xC -> 1xHxCx1
2753 op.ifm_shapes[0] = Shape4D([1, ifm_shape[0], ifm_shape[2], 1])
2754
2755 # IFM rank 4 cases
2756 elif len(ifm_shape) == 4:
2757 # Check if 1xHxWxC -> 1xWxHxC
2758 if perm.values[1] == 2 and perm.values[2] == 1:
2759 # IFM shape is correct
2760 pass
2761
2762 # Check if 1x1xWxC -> 1x1xCxW
2763 elif ifm_shape[1] == 1 and perm.values[2] == 3 and perm.values[3] == 2:
2764 # IFM shape: 1x1xWxC -> 1xWxCx1
2765 op.ifm_shapes[0] = Shape4D([1, ifm_shape[2], ifm_shape[3], 1])
2766
2767 # Check if 1xHx1xC -> 1xCx1xH
2768 elif ifm_shape[2] == 1 and perm.values[1] == 3 and perm.values[3] == 1:
2769 # IFM shape: 1xHx1xC -> 1xHxCx1
2770 op.ifm_shapes[0] = Shape4D([1, ifm_shape[1], ifm_shape[3], 1])
2771
2772 # OFM shape must use IFM shape
2773 op.ofm_shapes[0] = op.ifm_shapes[0]
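# Worked example: for the rank-2 (2,3) -> (3,2) transpose in the docstring above, the
# op becomes an AvgPool with ifm_shapes[0] = ofm_shapes[0] = [1, 2, 3, 1]; the actual
# transpose is realised later by swapping the OFM height/width strides when the output
# feature map is created.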
2774
2775 DebugDatabase.add_optimised(op, op)
2776
2777 return op
2778
2779
Tim Hall2180a172023-03-10 18:11:34 +00002780def fixup_reshape(op, arch, nng):
2781 def _get_explicit_shape(implicit_shape, total_size):
2782 # the explicit shape is a copy of the implicit shape but with the special -1 (remaining size) value converted to
2783 # the appropriate value
2784 if implicit_shape is None:
2785 return None
2786
2787 explicit_shape = list(implicit_shape)
2788 if -1 in explicit_shape:
2789 explicit_shape[explicit_shape.index(-1)] = int(total_size / abs(np.prod(implicit_shape)))
2790
2791 return explicit_shape
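# Worked example (hypothetical shape): an implicit shape of [-1, 4] with total_size 12
# becomes the explicit shape [3, 4], since 12 / abs(prod([-1, 4])) = 3.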
2792
2793 if op.type == Op.Reshape:
2794 ifm_tensor, _, ofm_tensor = op.get_ifm_ifm2_ofm()
2795 ifm_size = ifm_tensor.elements()
2796 ofm_shape = ofm_tensor.shape
2797
2798 new_shape_tensor_shape = op.inputs[1].values.flatten() if len(op.inputs) > 1 else None
2799 new_shape_tensor_shape = _get_explicit_shape(new_shape_tensor_shape, ifm_size)
2800
2801 new_shape_attribute = op.attrs.get("new_shape", None)
2802 new_shape_attribute = _get_explicit_shape(new_shape_attribute, ifm_size)
2803
2804 # if present the new shape tensor overrides the new_shape attribute
2805 if new_shape_tensor_shape is not None:
2806 # check tensor
2807 if not np.array_equal(new_shape_tensor_shape, ofm_shape):
2808 print(
2809 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new shape tensor"
2810 f" ({new_shape_tensor_shape}) that does not match output tensor shape {ofm_shape}. Will use output"
2811 f" tensor shape."
2812 )
2813 elif new_shape_attribute is not None:
2814 # check attribute
2815 if not np.array_equal(new_shape_attribute, ofm_shape):
2816 print(
2817 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new_shape attribute"
2818 f" ({new_shape_attribute}) that does not match output tensor shape {ofm_shape}. Will use output"
2819 f" tensor shape."
2820 )
2821 else:
2822 print(
2823 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' does not have a new shape tensor or a new_shape"
2824 f" attribute. Will use output tensor shape {ofm_shape}."
2825 )
2826
2827 # force new shape tensor to output shape
2828 new_shape_tensor = create_const_tensor(
2829 op.name + "_new_shape", [len(ofm_shape)], DataType.int32, np.array(ofm_shape, np.int32)
2830 )
2831 if len(op.inputs) > 1:
2832 op.set_input_tensor(new_shape_tensor, 1)
2833 else:
2834 op.add_input_tensor(new_shape_tensor)
2835
2836 # force new_shape attribute to output shape
2837 op.attrs["new_shape"] = ofm_shape
2838
2839 return op
2840
2841
Tim Hall9cf63a32023-06-27 12:07:49 +01002842def convert_conv_groups(op: Operation, arch, nng):
2843 """
2844 Convert convolution groups to a split followed by separate convolutions and then a concat.
2845 This needs to run before the concat and split handling functions."""
2846 if not op.type.is_conv2d_op():
2847 return op
2848
2849 num_conv_groups = op.attrs.get("num_conv_groups", 0)
2850 if num_conv_groups > 1:
2851 # convolution groups params
2852 ifm_depth_cg = op.ifm.shape[-1] // num_conv_groups
2853 num_filters_cg = op.weights.shape[-1] // num_conv_groups
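# Worked example (hypothetical shapes): a Conv2D with IFM depth 16, 64 filters and
# num_conv_groups = 2 gives ifm_depth_cg = 8 and num_filters_cg = 32, i.e. two
# convolutions each reading 8 input channels and writing 32 output channels, whose
# results are concatenated back to 64 channels along the depth axis.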
2854
2855 # create split
2856 split_op = Operation(Op.Split, f"{op.name}_split")
2857 split_op.attrs.update(
2858 {
2859 "num_splits": num_conv_groups,
2860 }
2861 )
2862 # first input is the split axis
2863 split_op.add_input_tensor(
2864 # split along the depth axis
2865 create_const_tensor(f"{split_op.name}_axis", [0], DataType.int32, [-1])
2866 )
2867 # second input is the ifm
2868 split_op.add_input_tensor(op.ifm)
2869 # calculate shape of each ofm part
2870 split_op_ofm_shape = op.ifm.shape[:-1] + [ifm_depth_cg]
2871
2872 # create concat. do this prior to each conv group so that the for-loop can reference the concat as it iterates
2873 concat_op = Operation(Op.ConcatTFLite, f"{op.name}_concat")
2874 concat_op.attrs.update(
2875 {
2876 "axis": -1,
2877 "fused_activation_function": None,
2878 }
2879 )
2880 # calculate shape of each ifm part
2881 concat_op_ifm_shape = op.ofm.shape[:-1] + [num_filters_cg]
2882 # output is the concatenated tensor
2883 concat_op.set_output_tensor(op.ofm) # will disconnect ofm from op
2884
2885 # for each conv group
2886 for i in range(num_conv_groups):
2887 # cg params
2888 cg_oc_start = i * num_filters_cg
2889 cg_oc_end = (i + 1) * num_filters_cg
2890
2891 # split has multiple outputs
2892 split_op_ofm_part = Tensor(split_op_ofm_shape, op.ifm.dtype, f"{split_op.name}_out{i}")
2893 split_op_ofm_part.quantization = op.ifm.quantization.clone()
2894 split_op.add_output_tensor(split_op_ofm_part)
2895
2896 # concat has multiple inputs
2897 concat_op_ifm_part = Tensor(concat_op_ifm_shape, op.ifm.dtype, f"{concat_op.name}_in{i}")
2898 concat_op_ifm_part.quantization = op.ofm.quantization.clone()
2899 concat_op.add_input_tensor(concat_op_ifm_part)
2900
2901 # create convolution group operator
2902 conv_group_op = Operation(op.type, f"{op.name}_cg{i}")
2903 conv_group_op.attrs = op.attrs.copy()
2904 conv_group_op.attrs["num_conv_groups"] = 1
2905 # first input is the ifm
2906 conv_group_op.add_input_tensor(split_op_ofm_part)
2907 # second input is weights. The number of filters (i.e. the output channels) needs to be split equally
2908 # across all of the convolution groups
2909 conv_group_op_weights_shape = op.weights.shape[:-1] + [num_filters_cg]
2910 conv_group_op_weights_quant = op.weights.quantization.clone()
2911 conv_group_op_weights_quant.scale_f32 = op.weights.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
2912 conv_group_op_weights_quant.zero_point = op.weights.quantization.zero_point[..., cg_oc_start:cg_oc_end]
2913 conv_group_op.add_input_tensor(
2914 create_const_tensor(
2915 f"{op.weights.name}_cg{i}",
2916 conv_group_op_weights_shape,
2917 op.weights.dtype,
2918 op.weights.values[..., cg_oc_start:cg_oc_end],
2919 op.weights.purpose,
2920 conv_group_op_weights_quant,
2921 )
2922 )
2923 # third input is bias. like the weights, the bias needs to be split equally across all of the convolution
2924 # groups
2925 if op.bias is None:
2926 conv_group_op.add_input_tensor(None)
2927 else:
2928 conv_group_op_bias_shape = op.bias.shape[:-1] + [num_filters_cg]
2929 conv_group_op_bias_quant = op.bias.quantization.clone()
2930 conv_group_op_bias_quant.scale_f32 = op.bias.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
2931 conv_group_op_bias_quant.zero_point = op.bias.quantization.zero_point[..., cg_oc_start:cg_oc_end]
2932 conv_group_op.add_input_tensor(
2933 create_const_tensor(
2934 f"{op.bias.name}_cg{i}",
2935 conv_group_op_bias_shape,
2936 op.bias.dtype,
2937 op.bias.values[..., cg_oc_start:cg_oc_end],
2938 op.bias.purpose,
2939 op.bias.quantization,
2940 )
2941 )
2942 # output goes to the concat
2943 conv_group_op.set_output_tensor(concat_op_ifm_part)
2944 # update the cg op shapes and debug db
2945 conv_group_op.set_ifm_ofm_shapes()
2946 DebugDatabase.add_optimised(op, conv_group_op)
2947
2948 # update the split/concat op shapes/debug db
2949 split_op.set_ifm_ofm_shapes()
2950 DebugDatabase.add_optimised(op, split_op)
2951 concat_op.set_ifm_ofm_shapes()
2952 DebugDatabase.add_optimised(op, concat_op)
2953
2954 # disconnect the original convolution operator.
2955 # the ofm has already been disconnected by concat_op.set_output_tensor()
2956 op.ifm.consumer_list.remove(op)
2957 op.inputs = []
2958 op.outputs = []
2959
2960 # return last op so that other graph optimiser functions can process the new operators
2961 op = concat_op
2962
2963 return op
2964
2965
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002966def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02002967 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002968 return op
2969
2970
wilisa0146c94772023-02-08 09:56:14 +00002971def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
Fredrik Svedberg11563172022-07-06 14:54:12 +02002972 # Compile time static optimisations
wilisa0146c94772023-02-08 09:56:14 +00002973 optimisation_list = [
2974 optimise_quantize,
2975 convert_shape_op_to_constant_tensor,
2976 fixup_or_check_asymmetric_weights(force_symmetric_int_weights),
Johan Gunnarsson24570f02023-08-29 15:33:10 +02002977 fixup_pool_strides,
wilisa0146c94772023-02-08 09:56:14 +00002978 ]
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002979
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002980 for idx, sg in enumerate(nng.subgraphs):
2981 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02002982 nng,
2983 sg,
2984 arch,
2985 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01002986 optimisation_list,
2987 rewrite_unsupported=False,
2988 )
2989
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002990 # Pre-processing step
Tim Hall9cf63a32023-06-27 12:07:49 +01002991 pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape, convert_conv_groups]
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002992
Ayaan Masood4965fae2022-06-29 11:30:57 +01002993 for idx, sg in enumerate(nng.subgraphs):
2994 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
2995 nng,
2996 sg,
2997 arch,
2998 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02002999 pre_process_list,
3000 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003001 )
3002
3003 # Handle Concat Ops
3004 for idx, sg in enumerate(nng.subgraphs):
3005 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
3006 sg.refresh_after_modification()
3007
3008 # Handle Split Ops
3009 for idx, sg in enumerate(nng.subgraphs):
3010 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
3011 nng,
3012 sg,
3013 arch,
3014 [],
3015 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
3016 rewrite_unsupported=False,
3017 )
3018
3019 for idx, sg in enumerate(nng.subgraphs):
3020 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02003021 nng,
3022 sg,
3023 arch,
3024 [rewrite_split_ops],
3025 [],
3026 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003027 )
3028
Johan Alfvena5e1b622023-02-02 14:59:03 +01003029 # Bypass or rewrite memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003030 for idx, sg in enumerate(nng.subgraphs):
3031 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02003032 nng,
3033 sg,
3034 arch,
3035 [],
Johan Alfvena5e1b622023-02-02 14:59:03 +01003036 [bypass_memory_only_ops],
Jonas Ohlssond8575072022-03-30 10:30:25 +02003037 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003038 )
3039
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003040 # Rewrite of operators
3041 op_rewrite_list = [
3042 set_tensor_equivalence,
Johan Alfvence502732023-04-24 13:35:40 +02003043 convert_ops_to_lut,
Johan Alfven906c9e82023-05-25 11:18:50 +02003044 convert_squared_difference,
Rickard Bolina68b82a2023-04-20 15:12:28 +00003045 convert_mean_to_depthwise_conv,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003046 convert_depthwise_to_conv,
3047 convert_conv_to_fc,
Fredrik Svedberg0ac08042023-04-11 22:35:04 +02003048 convert_lstm,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003049 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02003050 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02003051 convert_mul_max_to_abs_or_lrelu,
3052 convert_lrelu,
Raul Farkas3e7157b2023-05-09 09:09:17 +01003053 convert_avg_pool_to_conv2d,
Rickard Bolinfdbb0722023-09-05 11:38:19 +00003054 convert_mirror_pad,
Raul Farkas69782af2023-05-09 10:39:52 +01003055 fixup_strided_conv,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003056 convert_hardswish_to_lut,
3057 rewrite_fully_connected_input,
3058 convert_batched_fc_shape,
3059 fixup_conv2d_backprop,
3060 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003061 reorder_depthwise_weights,
Rickard Bolin6986a072022-12-19 12:33:40 +00003062 convert_argmax_to_depthwise_conv_and_max_pool,
Tim Hall885033b2022-07-21 11:46:03 +01003063 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003064 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01003065 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003066 convert_tanh_sigmoid_to_lut,
Johan Gunnarsson98556372023-08-10 13:10:44 +02003067 convert_quantize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003068 replace_pad_by_hw_pad,
Tim Hallea4ba662022-11-11 18:19:53 +00003069 fixup_dilation_gt2,
Johan Alfvena8fda882023-10-28 16:04:46 +02003070 fixup_transpose,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003071 ]
3072
3073 for idx, sg in enumerate(nng.subgraphs):
3074 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02003075 nng,
3076 sg,
3077 arch,
3078 [],
3079 op_rewrite_list,
3080 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003081 )
3082
3083 for idx, sg in enumerate(nng.subgraphs):
3084 # remove passthrough tensors and attempt further optimizations
3085 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
3086 nng,
3087 sg,
3088 arch,
3089 [remove_passthrough_tensor],
3090 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
3091 )
3092
3093 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
3094 # since ifm/ofm_shapes are of importance to this function
3095 for sg in nng.subgraphs:
3096 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
3097 sg.refresh_after_modification()
3098
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01003099 # Make sure that const optimisations on subgraph outputs are handled correctly
3100 for sg in nng.subgraphs:
3101 for ofm in sg.output_tensors:
3102 if ofm.is_const and ofm.ops[0].type_changed:
3103 # Subgraph output cannot be const - insert a memory copy
3104 op = ofm.ops[0]
3105 ofm_clone = ofm.clone()
3106 ofm_clone.values = ofm.values
3107 ofm.values = None
Tim Hall3b1578e2023-01-13 17:57:25 +00003108 zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01003109 memcpy = create_add_nop(f"{ofm.name}_copy")
3110 memcpy.add_input_tensor(ofm_clone)
3111 memcpy.add_input_tensor(zero)
3112 memcpy.set_output_tensor(ofm)
3113 memcpy.set_ifm_ofm_shapes()
3114 op.set_output_tensor(ofm_clone)
3115 DebugDatabase.add_optimised(op, memcpy)
3116
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02003117 return nng