# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
from __future__ import annotations

import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import create_avg_pool_for_concat
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .lstm import Lstm
from .lut import convert_to_lut
from .lut import create_lut_8bit_op
from .lut import create_lut_int16_op
from .lut import create_lut_rsqrt_int8_op
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .numeric_util import round_down_log2
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation import RoundingMode
from .operation_util import create_add
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_cast_op
from .operation_util import create_depthwise_maxpool
from .operation_util import create_memcpy
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype
from .utils import calc_resize_factor

passthrough_nodes = (Op.Identity,)


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))
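        # e.g. packing [2, 3] inputs on axis=1 gives desired_shape [2, 1, 3], which becomes
        # the 4D shape [1, 2, 1, 3], and axis_4D = 1 + (4 - 3) = 2 points at the new unit axis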

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]
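            # e.g. offset_start [0, 0, 2, 0] and offset_end [1, 4, 6, 8] give a read shape
            # of [1, 4, 4, 8] starting at the slice offset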

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
        # or if an avgpool needs to be inserted
        if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
            consumer is not None and consumer.run_on_npu and consumer.type not in memory_only_ops
            for consumer in op.ofm.consumer_list
        ):
            # SplitSliceRead can be performed by tensor consumer(s)
            for cons_op in list(op.ofm.consumer_list):
                move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
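    # e.g. SAME padding with a dilated 3x3 kernel, stride 1 and a 10x10 IFM gives
    # ypad = xpad = 2, padding = (1, 1, 1, 1) and skirt = (1, 1, 1, 1)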
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op: Operation, arch, nng) -> Operation:
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
        DebugDatabase.add_optimised(op, op)

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    name = op.inputs[1].name + "_add"
    dtype = op.inputs[0].dtype
    shape = op.ofm_shapes[0].as_list()
    values = np.zeros(shape, dtype.as_numpy_type())
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = 1.0
    quantization.zero_point = 0
    op.inputs[1] = op.inputs[0]
    op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 0)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change ResizeNearestNeighbor to Depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype.type == BaseType.UnsignedInt:
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
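    # e.g. upscale_factor = 2 gives centre_coeff = 1 * 2 + 1 = 3, i.e. the 2x2 kernel
    # [[0, 0], [0, 1]], which selects value D in the diagram above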

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm_dtype,
            np.array(weight_values).reshape(weight_shape),
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # set up the bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))
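    # e.g. upscale_factor = 8 gives n = 3: two x2 upscales in the loop below plus the final x2 upscale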

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

        if scaled_op.original_type == Op.ResizeBilinear:
            if scaled_op.attrs["align_corners"]:
                # no padding
                scaled_op.attrs["padding"] = Padding.VALID
            else:
                # padding to the right and bottom (limits average pool to 8x8 kernel)
                scaled_op.attrs["padding"] = Padding.EXPLICIT
                scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

            # kernel size dependent on the upscaling factor
            scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
        else:  # Op.ResizeNearestNeighbor
            if scaled_op.attrs["align_corners"]:
                # use depthwise conv to select the correct value
                scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
            else:
                # Keep 1x1 kernel and average pool, this applies both when
                # half-pixel-centers is True and False. Calculations are the
                # same in the reference.
                pass

        scaled_op.outputs = outputs
        scaled_op.outputs[0].ops = [scaled_op]
        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    return op


def convert_argmax_to_depthwise_conv_and_max_pool(op: Operation, arch, nng) -> Operation:
    """
    Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.

    Example:
    arr = [4,    [00000100,
           6, =   00000110,  # <-- This is the largest value, so we're expecting argmax(arr) = 1
           5]     00000101]

    Use 16-bit precision and shift all values 7 bits to the left:
    Shifted_arr = [0000001000000000,
                   0000001100000000,
                   0000001010000000]

    Add "c - index of channel" to each channel:
    Shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
                                    0000001100000001, (+1)
                                    0000001010000000] (+0)

    The index is reversed since ArgMax selects the lowest index if the maximum value is found at two indices. The
    index will act as a tie-breaker between channels with equal values and since we want the smallest channel index
    to be chosen we reverse the index before the maxpool and then subtract the index from the number of channels
    after the maxpool to get the correct index.

    Find the maximum value in the array:
    val = max(shifted_arr_plus_reverse_idx) = 0000001100000001

    Subtract the value from the number of channels:
    shifted_arr_plus_idx = (c-1) - val = 2 - 1 = 1

    Extract the 7 lowest bits using a LUT to cut off the 9 most significant bits:
    idx = LUT(val) = 0000000000000001 = 1
    """

    if op.type == Op.ArgMax:
        ifm, ofm = op.inputs[0], op.outputs[0]
        identity_quant = QuantizationParameters()
        identity_quant.zero_point = 0
        identity_quant.scale_f32 = 1.0
        # Add last dimension to ofm shape
        ofm.shape += [1]
        ofm.ops = []

        # Create 1x1 Depthwise convolution with 2**7 weights for each channel to convert precision to 16 bit and shift
        # all values 7 bits to the left
        # Set necessary depthwise attributes
        dw_op_attrs = {
            "padding": Padding.VALID,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
            "explicit_padding": None,
        }
        orig_name = op.name
        op.name = f"{orig_name}_depthwise_conv_SHL_7"
        op.type = Op.DepthwiseConv2DBias
        op.attrs.update(dw_op_attrs)
        n, h, w, c = full_shape(4, ifm.shape, 1)
        shape = [1, 1, 1, c]
        kernel = np.dstack([2**7] * c)
        op.inputs = []
        op.add_input_tensor(ifm)
        op.add_input_tensor(
            create_const_tensor(
                "weights",
                shape,
                DataType.uint8,
                np.array(kernel).reshape(shape),
                quantization=identity_quant,
            ),
        )
        # Let the bias for each channel be the "reverse" index of the channel it is in, i.e. c - channel_idx
        reverse_idxs = list(reversed(range(c)))
        bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
        op.add_input_tensor(bias_tensor)

        intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
        intermediate_tens.quantization = ifm.quantization
        op.set_output_tensor(intermediate_tens)
        op.set_ifm_ofm_shapes()
        orig_ifm_shape = op.ifm_shapes[0]
        DebugDatabase.add_optimised(op, op)

        # To extract 7 least significant bits and swap reverse index back to real index using a LUT activation, we set
        # the base value to c-1 and slope to -128. The 16-bit LUT uses a table of 32-bit values where the top 16 bits
        # represent the slope and bottom 16 bits the base which are used to interpolate the activation value.
        slope = (-128 & 0xFFFF) << 16  # Top 16 bits of 32 bit LUT table value
        base = c - 1  # Bottom 16 bits of the LUT table value
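        # e.g. with c = 3 each 32-bit table entry is (0xFF80 << 16) + 2 = 0xFF800002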
        lut_tensor = create_const_tensor(
            "maxpool_LUT_extract_7_LSB",
            [1, 1, 1, 512],
            DataType.uint32,
            [slope + base] * 512,
            TensorPurpose.LUT,
        )

        # Split large feature maps into smaller chunks since the Depthwise Maxpool height dimension can overflow due to
        # flattening the ifm to (H*W)xCx1
        max_height = 2**16 // orig_ifm_shape.width
        num_full_height_ops = orig_ifm_shape.height // max_height
        last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
        op_heights = [max_height] * num_full_height_ops
        if last_op_height > 0:
            op_heights.append(last_op_height)
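        # e.g. a 100x700 IFM gives max_height = 65536 // 700 = 93 and op_heights = [93, 7]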

        # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
        # maximum allowed height, but that's handled by reading and writing the data in chunks
        maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
        maxpool_ofm.quantization = identity_quant

        for op_idx, op_height in enumerate(op_heights):
            maxpool_op = create_depthwise_maxpool(
                f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
            )
            maxpool_op.outputs = [maxpool_ofm]
            maxpool_ofm.ops.append(maxpool_op)
            maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
            maxpool_op.set_activation_lut(lut_tensor)

            # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
            maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
            maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
            maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
            maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
            DebugDatabase.add_optimised(op, maxpool_op)

        # Set final shape
        maxpool_ofm.set_all_shapes([1, h, w, 1])

        # Convert 16bit to 32bit or 64bit
        if ofm.dtype == DataType.int64:
            # If OFM dtype is int64 the result is converted by two cast ops (16bit to 32bit)
            #
            #   A     -> B          -> C           -> D (OFM)
            #   |0001|   |00010000|    |0001|0000|    |00010000|00000000|
            #    i16       i32           i16  i16       i32      i32
            #                                          <-------i64------->
            #
            #   Memcpy is used to copy the content from B to C and from D to OFM
            #   Memcpy will be turned into a nop or a DMA transfer if the memory regions differ.
            intermediate_32bit = Tensor([1, h, w, 1], DataType.int32, f"{orig_name}_32bit")
        else:
            intermediate_32bit = ofm

        op_cast = create_cast_op(f"{orig_name}_cast_to_32bit_1", maxpool_ofm, intermediate_32bit)
        DebugDatabase.add_optimised(op, op_cast)

        if ofm.dtype == DataType.int64:
            # Create int16 tensor with double shape to cover the intermediate_32bit result from the first cast
            intermediate_16bit_2x_size = Tensor([1, h, w, 2], DataType.int16, f"{orig_name}_16bit_2x_size")
            memcpy_op = create_memcpy(f"{orig_name}_memcpy_1", intermediate_32bit, intermediate_16bit_2x_size)
            DebugDatabase.add_optimised(op, memcpy_op)

            # Create int32 tensor with double ofm shape to be able to store a "int64" result
            intermediate_32bit_2x_size = Tensor([1, h, w, 2], DataType.int32, f"{orig_name}_32bit_2x_size")

            op_cast = create_cast_op(
                f"{orig_name}_cast_to_32bit_2", intermediate_16bit_2x_size, intermediate_32bit_2x_size
            )
            DebugDatabase.add_optimised(op, op_cast)

            memcpy_op = create_memcpy(f"{orig_name}_memcpy_2", intermediate_32bit_2x_size, ofm)
            DebugDatabase.add_optimised(op, memcpy_op)

    return op


def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
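                # e.g. for x2 upscaling with half-pixel centres the four kernels are the
                # reflections of [[1, 3], [3, 9]]; each sums to 16, matching the 1/16 weight scale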
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Resize bilinear requires rounding away from zero
        dw_conv.rounding_mode = RoundingMode.AwayZero

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        quantization=quant,
                    ),
                )

                # set up the bias tensor by assigning None and then calling the fix-up function to create a suitable
                # tensor. need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, dw_conv)

                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op: Operation, arch, nng) -> Operation:
    """Fixup resize ops to increase support for ResizeNearestNeighbor cases."""
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng) -> Operation:
    """Rewrite FullyConnected shape as 2D to allow it to run on NPU."""
    # If the operation already has a read shape do not modify
    # the ifm shape, since that will already be correct
    if op.type == Op.FullyConnected and not op.read_shapes[0]:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
    """Convert batched FullyConnected op shape to allow for support on NPU."""
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
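            # e.g. a batch-8 IFM [8, 1, 1, depth] is reshaped to [1, 2, 4, depth]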

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, act_op)


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op: Operation, arch, nng) -> Operation:
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        if not weight_tensor.weight_transpose_depthwise:
            weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
            weight_tensor.weight_transpose_depthwise = True

    return op


def convert_avg_pool_to_conv2d(op: Operation, arch, nng) -> Operation:
    """Convert strided Average Pools with stride >= 4 to Conv2D."""
    if op.type != Op.AvgPool:
        return op

    stride_x, stride_y = op.get_kernel_stride()
    # For strides <= 3 no optimization is needed
    if stride_x <= 3:
        return op
    h, w = op.attrs["filter_height"], op.attrs["filter_width"]
    inputs = op.inputs[0]
    shape = inputs.shape

    # Set necessary conv2d attributes
    op.attrs.update(
        {
            "stride_h": stride_y,
            "stride_w": stride_x,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "strides": (1, stride_y, stride_x, 1),
            "dilation": (1, 1, 1, 1),
        }
    )

    # Change op type
    op.type = Op.Conv2DBias
    op.name += "_conv2d"

    op.rounding_mode = RoundingMode.AwayZero
    shape = [h, w, 1, op.ofm.shape[-1]]
    weights = np.full(shape, 1)
    quant = QuantizationParameters(scale_f32=1 / (h * w), zero_point=0)
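    # e.g. a 5x5 average pool becomes a conv with 5x5 unit weights and weight scale
    # 1 / 25, which reproduces the mean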
    # Add unit weight tensor
    op.add_input_tensor(
        create_const_tensor(
            "weights",
            shape,
            inputs.dtype,
            weights,
            quantization=quant,
        ),
    )
    op.weights.values = np.reshape(op.inputs[1].values, shape)

    # Set IFM/OFM shapes after changing op type
    op.set_ifm_ofm_shapes()
    return op


def fixup_strided_conv(op: Operation, arch, nng):
    """Optimize or fixup strided Conv2DBias
    Optimization:
        Reduce, when possible, the Conv2DBias stride from N with 1 < N < 4 to 1
        by re-shaping both IFM and filter.

    Fixup:
        Introduce software support for Conv2DBias with stride_width > 4 by
        reducing it to 1, 2 or 3 (HW supported strides) when possible by
        re-shaping both IFM and filter.
    """
    if op.type != Op.Conv2DBias:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    # Do not optimize if op is not the first in the network and stride is
    # supported by the hardware
    if op.op_index != 0 and stride_x < 4:
        return op

    resize_factor, final_stride = calc_resize_factor(ifm_shape.width, stride_x)
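    # calc_resize_factor is assumed to split stride_x into resize_factor * final_stride with
    # final_stride HW supported, e.g. stride 4 -> 2 * 2 (when the IFM width allows the split)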

    def calc_filter_padding(
        ifm_padding_type: Padding | None,
        ifm_current_padding_x: int,
        post_op_stride: int,
        opt_resize_factor: int,
        filter_width: int,
        ifm_width: int,
    ) -> tuple[int, int, int, int]:
        """Calculate zero padding to be added to the filter.

        Parameters
        ----------
        ifm_padding_type : Padding or None
            The padding type that is applied to the IFM.
        ifm_current_padding_x : int
            Padding amount that is added to the IFM before optimization.
        post_op_stride : int
            The final stride once optimization is performed.
        opt_resize_factor : int
            The factor by which the stride will be reduced.
            E.g. opt_resize_factor = 2 on a stride of 4 will produce
            a stride of 2 after the optimization
        filter_width : int
            Width of the filter before optimization.
        ifm_width : int
            Width of the IFM before optimization

        Returns
        -------
        padding : tuple[int, int, int, int]
            A tuple with the amount of padding on each side (top, left, bottom, right)
        """
        padding_size = 0
        padding = (0, 0, 0, 0)
        if ifm_padding_type and ifm_padding_type != Padding.VALID:
            # Compute padding size for the filter that guarantees that HW padding added to IFM matches
            # before and after the optimization is performed
            expected_filter_size = 0
            pre_opt_stride = post_op_stride * opt_resize_factor
            post_opt_ifm_width = ifm_width // opt_resize_factor
            # Compute the total expected filter size post optimization that ensures that the same HW padding
            # is added to IFM.
            # There are two ways of calculating required filter size depending on whether IFM width is divisible
            # by stride width or not. These approaches match the cases used to calculate HW padding in
            # needed_total_padding method.
            if ifm_width % pre_opt_stride == 0:
                expected_filter_size = ifm_current_padding_x + post_op_stride
            else:
                expected_filter_size = ifm_current_padding_x + (post_opt_ifm_width % post_op_stride)
            # Compute padding size from expected filter size
            padding_size = expected_filter_size * opt_resize_factor - filter_width

            if ifm_current_padding_x == 0:
                # If no HW padding is added to IFM, divide filter padding between left and right following
                # the same strategy as the reference.
                padding_left = padding_size // 2
            else:
                # If HW padding is added to IFM, split padding for the filter so that left padding and right padding
                # are proportional to left and right HW padding.
                left_hw_padding = ifm_current_padding_x // 2
                # Compute filter padding
                padding_left = padding_size // ifm_current_padding_x * left_hw_padding
            padding = (0, padding_left, 0, padding_size - padding_left)

        # Check if filter width is divisible by the stride width (required for optimization)
        # If filter width is not divisible by stride width and no HW padding is added to IFM, compute
        # filter padding required for the filter width to be divisible by the stride width and apply it as right
        # padding.
        if filter_width % opt_resize_factor != 0 and (padding_size == 0 or ifm_current_padding_x == 0):
            padding_size = opt_resize_factor - (filter_width % opt_resize_factor)
            # Add padding zeros to the right
            padding = (0, 0, 0, padding_size)

        return padding

    # Compute the depth of the IFM once the strided Conv2D is optimised
    post_opt_ifm_depth = ifm_shape.depth * resize_factor

    if stride_x > 1 and (post_opt_ifm_depth <= 8 or stride_x > 3) and resize_factor != 1 and weight_tensor is not None:
        k_w, _ = op.get_kernel_size()
        weight_shape = weight_tensor.shape

        padding_type = op.attrs.get("padding", None)
        if padding_type in (None, Padding.EXPLICIT, Padding.TILE):
            return op
        # Compute current padding as if IFM padding is SAME
        curr_padding_x = needed_total_padding(ifm_shape.width, stride_x, k_w)
        # Compute the padding needed on the filter for the optimisation
        _, left_filter_padding, _, right_filter_padding = calc_filter_padding(
            padding_type, curr_padding_x, final_stride, resize_factor, k_w, ifm_shape.width
        )
        total_horizontal_padding = left_filter_padding + right_filter_padding
        # If IFM padding is enabled, check if pre-opt and post-opt padding is
        # the same while taking into consideration the extra filter padding.
        if padding_type == Padding.SAME:
            optimised_padding_x = needed_total_padding(
                ifm_shape.width // resize_factor, final_stride, (k_w + 1 + total_horizontal_padding) // resize_factor
            )
            if curr_padding_x != optimised_padding_x:
                # Horizontal padding would become different after optimisation; this would not work
                return op

        # Resize IFM
        op.ifm_shapes[0] = Shape4D(
            [ifm_shape.batch, ifm_shape.height, ifm_shape.width // resize_factor, ifm_shape.depth * resize_factor]
        )

        # Compute list of 0 padding for each dimension of the filter
        filter_dimension_padding = [(0, 0) for _ in weight_tensor.shape]
        # Update padding for filter width with computed padding
        filter_dimension_padding[1] = (left_filter_padding, right_filter_padding)
        # Add padding to the filter
        zero_point = weight_tensor.quantization.zero_point
        padding_constant = zero_point if np.isscalar(zero_point) else 0
        padded_filter_tensor = np.pad(weight_tensor.values, filter_dimension_padding, constant_values=padding_constant)
        weight_shape[1] = padded_filter_tensor.shape[1]
        weight_tensor.values = padded_filter_tensor
        # Change weight shape based on stride_x
        weight_shape[1] //= resize_factor
        weight_shape[2] *= resize_factor

        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = final_stride
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op: Operation, arch, nng) -> Operation:
    """Convert a 1x1 Conv2D that behaves like FullyConnected to FullyConnected, since it doesn't need any weight
    buffering.
    """
    # Conv 1x1 can be equivalent to Fully Connected.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op: Operation, arch, nng) -> Operation:
    """Fixup Relu with different IFM and OFM to allow fusing by adding its own primary op."""
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
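            # quantise_scale is assumed to approximate rescale as multiplier * 2^-shift;
            # e.g. ifm scale 0.5 and ofm scale 0.25 give rescale = 2.0, applied by the avgpool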
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_lstm(op: Operation, arch, nng) -> Operation:
    """Convert LSTM op into its basic operations to allow for support on NPU."""
    if op.type == Op.UnidirectionalSequenceLstm:
        lstm = Lstm(op)
        op = lstm.get_graph()
    return op


def convert_softmax(op: Operation, arch, nng) -> Operation:
    """Convert Softmax op into its basic operations to allow for support on NPU."""
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op: Operation, arch, nng) -> Operation:
    """Convert PReLU op to other ops based on alpha values to allow for support on NPU."""
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()
                    DebugDatabase.add_optimised(op, mul_identity)

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
1302 return max_op
1303
1304 # Catch-all PReLU conversion for the cases that could not be optimised above
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001305 no_scale_quant = ifm.quantization.clone()
1306 no_scale_quant.scale_f32 = None
1307 no_scale_quant.zero_point = 0
Fredrik Svedberg66591652022-08-29 10:51:27 +02001308 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001309
1310 # Select values < 0
1311 min_op = Operation(Op.Minimum, op.name + "_min")
1312 min_op.add_input_tensor(ifm)
1313 min_op.add_input_tensor(zero)
1314 fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
1315 min_op.set_output_tensor(fm_negative)
1316 min_op.set_ifm_ofm_shapes()
1317 DebugDatabase.add_optimised(op, min_op)
1318
1319 # and multiply with alpha tensor
1320 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
1321 mul_alpha.add_input_tensor(fm_negative)
1322 mul_alpha.add_input_tensor(alpha)
1323 fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
1324 mul_alpha.set_output_tensor(fm_alpha)
1325 mul_alpha.set_ifm_ofm_shapes()
1326 DebugDatabase.add_optimised(op, mul_alpha)
1327
1328 # Select (and scale) values > 0
1329 relu_op = Operation(Op.Relu, op.name + "_relu")
1330 relu_op.add_input_tensor(ifm)
1331 fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1332 relu_op.set_output_tensor(fm_scaled)
1333 relu_op.set_ifm_ofm_shapes()
1334 DebugDatabase.add_optimised(op, relu_op)
1335
1336 # Add scaled and alpha multiplied values (without scaling)
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001337 add_op = Operation(Op.Add, op.name + "_add")
1338 add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001339 add_op.add_input_tensor(fm_alpha)
1340 add_op.add_input_tensor(fm_scaled)
1341 add_op.set_output_tensor(ofm)
1342 add_op.set_ifm_ofm_shapes()
1343
1344 DebugDatabase.add_optimised(op, add_op)
1345 ifm.consumer_list.remove(op)
1346 op = add_op
1347
1348 return op
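# The decompositions above rest on the PReLU definition
#   prelu(x) = x            for x >= 0
#   prelu(x) = alpha * x    for x <  0
# so for alpha <= 1 it equals max(alpha * x, x), and in general it equals
# relu(x) + alpha * min(x, 0), which is what the catch-all graph computes.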
1349
1350
Raul Farkas66207142023-05-25 11:15:20 +01001351def convert_mul_max_to_abs_or_lrelu(op: Operation, arch, nng) -> Operation:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001352 r"""Whenever there is a subgraph with this topology:
1353
Jonas Ohlssond8575072022-03-30 10:30:25 +02001354 Input    X    For X = -1 or X > 0
1355 |   \   /     This subgraph can be replaced with either
1356 |    Mul      an Abs (if X = -1) or a LeakyReLU (if X > 0)
1357 |   /
1358 Max
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001359 """
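# The rewrite relies on the identities max(x, -1 * x) == abs(x) and, for
# 0 <= alpha <= 1, max(x, alpha * x) == leaky_relu(x, alpha); both
# multiplications reuse the shared input X shown in the diagram above.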
1360
1361 if op.type == Op.Maximum:
1362 # finds the Mul input(s) to the Max
1363 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
1364 if len(muls) == 1:
1365 mul = muls[0].ops[0]
1366 elif len(muls) == 2:
1367 # In the case both inputs are Muls, find the one with the same input as the Max
Fredrik Svedberg66591652022-08-29 10:51:27 +02001368 mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
1369 if len(mul_ifms):
1370 mul = mul_ifms[0].ops[0]
1371 else:
1372 # Not using same input
1373 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001374 else:
1375 # No Mul inputs
1376 return op
1377
1378 # make sure the Mul doesn't have any other consumers
1379 mul_ofm = mul.outputs[0]
1380 if len(mul_ofm.consumers()) != 1:
1381 return op
1382 # make sure the Mul doesn't have a fused activation function
1383 if mul.activation:
1384 return op
1385 ifm, ofm = op.get_ifm_ofm()
1386 if ifm is None or ofm is None:
1387 return op
1388
1389 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1390 return op
1391 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
1392 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
1393 return op
1394
1395 # finds the branched input that goes to both the Max and the Mul
1396 shared = set(op.inputs) & set(mul.inputs)
1397 if len(shared) == 1:
1398 shared_in = shared.pop()
1399 # find the constant scalar input to the Mul
1400 const_tens = (set(mul.inputs) - {shared_in}).pop()
1401 # check that it is a scalar
1402 if const_tens.shape != []:
1403 return op
1404 const = const_tens.ops[0]
1405 # check that it is a constant
1406 if const.type != Op.Const:
1407 return op
1408 # Remove the Mul from the shared input's consumers
1409 shared_in.consumer_list.remove(mul)
1410 else:
1411 return op
1412
1413 val = const.outputs[0].values
1414 if val >= 0:
1415 new_op = Op.LeakyRelu
1416 op.attrs["alpha"] = val
1417 # to produce bit exact results, the alpha is not enough;
1418 # save additional scaling info in attr "alpha_scaling", to be used as input
1419 # to the LUT construction
James Peet7519d502021-07-19 16:47:58 +01001420 alpha_scalar = const_tens.values - const_tens.quantization.zero_point
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001421 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
1422 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
1423 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
1424 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
1425 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
1426 elif val == -1:
1427 new_op = Op.Abs
1428 else:
1429 return op
1430
1431 op.type = new_op
1432 op.name = op.name.replace("Maximum", new_op.name)
1433 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
1434 op.inputs = [shared_in]
1435 op.set_ifm_ofm_shapes()
1436
1437 # Record optimisation in debug database
1438 DebugDatabase.add_optimised(op, op)
1439
1440 return op
1441
1442
Raul Farkas66207142023-05-25 11:15:20 +01001443def convert_hardswish_to_lut(op: Operation, arch, nng) -> Operation:
1444 """Convert HardSwish to LUT to allow for support on NPU."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001445 if op.type == Op.HardSwish:
1446 ifm, ofm = op.get_ifm_ofm()
1447 # Generate the LUT
1448 ifm_scale = np.double(ifm.quantization.scale_f32)
1449 ofm_scale = np.double(ofm.quantization.scale_f32)
1450 zp_in = ifm.quantization.zero_point
1451 zp_out = ofm.quantization.zero_point
1452 ifm_scale_hires = (1 / 128) * ifm_scale
1453 relu_multiplier = np.double(3 / 32768)
1454 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
1455 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
1456 # Use 16bit scale
1457 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
1458 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
1459
1460 values = []
1461 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1462 quantized_min = min(ix)
1463 quantized_max = max(ix)
1464 for x in ix:
1465 input_value = x - zp_in
1466 input_value_hires = input_value * 128
1467 # Compute the input value on essentially the output scale, not shifted yet
1468 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
1469 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
1470 relu_value = np.int16(input_value_hires)
1471 if relu_shift < 31:
1472 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
1473
1474 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
1475
1476 if relu_shift < 31:
1477 relu_value = fp_math.shift_left16(relu_value, 1)
1478
1479 if relu_shift > 31:
1480 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
1481
1482 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
1483 # Now convert that to a 16bit fixedpoint value in [0, 1]
1484 relu_value = (relu_value + (1 << 15)) >> 1
1485 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
1486 shift = 31 - out_shift
1487 shift = -shift if shift < 0 else 0
1488 # Finally apply the output shift
1489 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
1490 lut_result = min(quantized_max, max(quantized_min, lut_result))
1491 values.append(lut_result)
1492 return convert_to_lut(op, values, "hardswish")
1493 return op
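# For reference, the float function tabulated above is
#   hardswish(x) = x * relu6(x + 3) / 6, with relu6(x) = min(max(x, 0), 6);
# the fixed-point steps follow the TensorFlow Lite Micro kernel (see the
# comments above) rather than evaluating the float formula directly.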
1494
1495
1496def convert_lrelu_to_mul_max(op, arch):
1497 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
1498 # (the opposite of convert_mul_max_to_abs_or_lrelu)
1499 ifm, ofm = op.get_ifm_ofm()
1500 if ifm is None or ofm is None:
1501 return op
1502
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001503 alpha = np.float32(op.attrs["alpha"])
1504 use_mul_max = 0 < alpha < 1
Fredrik Svedberg36424312022-09-16 09:39:26 +02001505 is_converted_prelu = "alpha_scaling" in op.attrs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001506 if use_mul_max:
1507 mul_ifm = ifm
1508 new_op = Op.Maximum
1509 else:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001510 # Need to use a different approach for alpha < 0 or alpha > 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001511 no_scale_quant = ifm.quantization.clone()
1512 no_scale_quant.scale_f32 = None
1513 no_scale_quant.zero_point = 0
1514 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
1515
1516 # Select values < 0
1517 min_op = Operation(Op.Minimum, op.name + "_min")
1518 min_op.add_input_tensor(ifm)
1519 min_op.add_input_tensor(zero)
1520 mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001521 if alpha < 0 and not is_converted_prelu:
1522 # For negative alpha that is not from a converted PReLU we need to use
1523 # int32 Mul below to perform the (negative) alpha scaling
1524 mul_ifm.dtype = DataType.int32
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001525 min_op.set_output_tensor(mul_ifm)
1526 min_op.set_ifm_ofm_shapes()
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001527 new_op = Op.Add
1528 op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001529 DebugDatabase.add_optimised(op, min_op)
1530
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001531 # Add multiplication with alpha
1532 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001533 mul_alpha.add_input_tensor(mul_ifm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001534 # Create const tensor containing alpha as scalar
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001535 quantization = ifm.quantization.clone()
1536 quantization.min = 0
1537 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
1538 quantization.zero_point = 0
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001539 alpha_dtype = mul_ifm.dtype
Fredrik Svedberg36424312022-09-16 09:39:26 +02001540 if is_converted_prelu:
1541 # The LeakyRelu was the result from convert_prelu and the scaling is provided
Fredrik Svedberg66591652022-08-29 10:51:27 +02001542 scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001543 mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001544 elif alpha == 0 or np.isinf(1 / alpha):
1545 # Handling of alpha near or at zero
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001546 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001547 scalar = 0
1548 else:
1549 quantization.scale_f32 = alpha
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001550 if alpha_dtype == DataType.int32:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001551 # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001552 scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
1553 else:
1554 scalar = 1
Tim Hall3b1578e2023-01-13 17:57:25 +00001555 alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001556 mul_alpha.add_input_tensor(alpha_tens)
1557 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1558 mul_alpha.set_output_tensor(fm_alpha)
1559 mul_alpha.set_ifm_ofm_shapes()
1560 DebugDatabase.add_optimised(op, mul_alpha)
1561
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001562 if not use_mul_max:
1563 relu_op = Operation(Op.Relu, op.name + "_relu")
1564 relu_op.add_input_tensor(ifm)
1565 fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1566 relu_op.set_output_tensor(fm_id)
1567 relu_op.set_ifm_ofm_shapes()
1568 DebugDatabase.add_optimised(op, relu_op)
1569 elif check_quantized_tens_scaling_equal(ifm, ofm):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001570 # No identity multiplication is needed
1571 fm_id = ifm
1572 else:
1573 # Add multiplication with identity
1574 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1575 mul_identity.add_input_tensor(ifm)
1576 # Create const tensor containing identity as scalar
1577 quantization = ifm.quantization.clone()
1578 quantization.min = 0
1579 quantization.max = quantization.quant_max - quantization.quant_min
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001580 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001581 quantization.zero_point = 0
Tim Hall3b1578e2023-01-13 17:57:25 +00001582 identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001583 mul_identity.add_input_tensor(identity_tens)
1584 # Make sure that fm_id is allocated to a different address than fm_alpha
1585 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1586 mul_identity.set_output_tensor(fm_id)
1587 mul_identity.set_ifm_ofm_shapes()
1588 DebugDatabase.add_optimised(op, mul_identity)
1589
1590 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001591 op.type = new_op
1592 op.name = op.name.replace("LeakyRelu", new_op.name)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001593 op.inputs = []
1594 ifm.consumer_list.remove(op)
1595 op.add_input_tensor(fm_alpha)
1596 op.add_input_tensor(fm_id)
1597 op.set_ifm_ofm_shapes()
1598
1599 DebugDatabase.add_optimised(op, op)
1600 return op
1601
1602
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001603def convert_to_lut8(op, fn, fn_name):
1604 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1605 # fn is a function(real) -> real
1606 ifm, ofm = op.get_ifm_ofm()
1607 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1608 return op
1609 # Generate the LUT
1610 ifm_scale = np.double(ifm.quantization.scale_f32)
1611 ofm_scale = np.double(ofm.quantization.scale_f32)
1612 zp_in = ifm.quantization.zero_point
1613 zp_out = ofm.quantization.zero_point
1614 values = []
1615 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1616 quantized_min = min(ix)
1617 quantized_max = max(ix)
1618 for x in ix:
1619 x_real = ifm_scale * (x - zp_in)
1620 y_real = fn(x_real)
1621 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1622 lut_result = min(quantized_max, max(quantized_min, lut_result))
1623 values.append(lut_result)
1624 return convert_to_lut(op, values, fn_name)
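# Worked example of a single LUT entry (hypothetical quantisation): with
# ifm_scale = 0.1, zp_in = 0, ofm_scale = 1 / 256, zp_out = -128 and
# fn = sigmoid, the entry for x = 0 is
#   round_away_zero(-128 + sigmoid(0.0) / (1 / 256)) = -128 + 128 = 0
# i.e. sigmoid(0) = 0.5 lands at the midpoint of the int8 output range.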
1625
1626
1627def convert_lrelu_to_lut(op, arch):
1628 ifm, ofm = op.get_ifm_ofm()
1629 # Generate the LUT
1630 alpha = op.attrs["alpha"]
1631 ifm_scale = np.double(ifm.quantization.scale_f32)
1632 ofm_scale = np.double(ofm.quantization.scale_f32)
1633 zp_in = ifm.quantization.zero_point
1634 zp_out = ofm.quantization.zero_point
1635 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1636 alpha_scalar = 1
1637 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1638 if "alpha_scaling" in op.attrs:
1639 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1640 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1641 values = []
1642 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1643 quantized_min = min(ix)
1644 quantized_max = max(ix)
1645 for x in ix:
1646 if x < zp_in:
1647 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1648 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1649 )
1650 else:
1651 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1652 lut_result = min(quantized_max, max(quantized_min, lut_result))
1653 values.append(lut_result)
1654 return convert_to_lut(op, values, "lrelu")
1655
1656
Raul Farkas66207142023-05-25 11:15:20 +01001657def convert_lrelu(op: Operation, arch, nng) -> Operation:
1658 """Convert LeakyRelu to a LUT based solution if possible, otherwise a mul + max."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001659 if op.type != Op.LeakyRelu:
1660 return op
1661 ifm, ofm = op.get_ifm_ofm()
1662 if ifm is None or ofm is None:
1663 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001664 alpha = op.attrs["alpha"]
1665 if alpha == 0:
1666 # When alpha is 0 the operation can be converted to a ReLU
1667 op.type = Op.Relu
1668 op.name = op.name.replace("LeakyRelu", op.type.name)
1669 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001670 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1671 # use LUT for int8/uint8
1672 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001673 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001674 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001675 return op
1676 return convert_lrelu_to_mul_max(op, arch)
1677
1678
Raul Farkas66207142023-05-25 11:15:20 +01001679def convert_tanh_sigmoid_to_lut(op: Operation, arch, nng) -> Operation:
1680 """Convert int8/uint8 Sigmoid and Tanh to a LUT based solution."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001681 if op.type == Op.Sigmoid:
1682 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1683 elif op.type == Op.Tanh:
1684 return convert_to_lut8(op, math.tanh, "tanh")
1685 return op
1686
1687
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001688def fuse_activation_function_with_prev(op, arch, nng):
1689 # if op is a no-op: attempts to move the activation function to the preceding op
1690 if not op.attrs.get("is_nop", False) or op.activation is None:
1691 return op
1692 ifm, ofm = op.get_ifm_ofm()
1693 if ifm is None or ofm is None:
1694 return op
1695 # finds the input(s) to the operation
1696 prev_op = ifm.ops[0]
1697 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1698 fuse = (
1699 prev_op.run_on_npu
1700 and prev_op.type.npu_block_type != NpuBlockType.Default
1701 and len(ifm.ops) == 1
1702 and len(prev_op.outputs[0].consumers()) == 1
1703 and prev_op.activation is None
1704 )
1705 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1706 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1707 # LUT currently only works correctly for elementwise ops
1708 fuse = False
1709 if not fuse:
1710 return op
1711 # Move the fused activation function + corresponding info to prev_op
1712 prev_op.activation = op.activation
1713 prev_op.forced_output_quantization = op.forced_output_quantization
1714 if op.activation_lut is not None:
1715 prev_op.set_activation_lut(op.activation_lut)
1716 # Bypass op
1717 prev_op.set_output_tensor(ofm)
wilisa0179a89042022-11-02 17:18:43 +00001718 DebugDatabase.add_optimised(prev_op, prev_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001719 return op
1720
1721
1722def _leading_pad_ok(leading_pad, stride, kernel_size):
1723 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1724 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
1725 max_size = kernel_size // 2
1726 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
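# Example: a 7-wide kernel with stride 2 gives max_size = 3, so a leading pad
# of 1 is rejected (it is neither 3 nor a multiple of 2) while 0, 2 and 3 pass.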
1727
1728
Raul Farkas66207142023-05-25 11:15:20 +01001729def replace_pad_by_hw_pad(op: Operation, arch, nng) -> Operation:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001730 """
1731 Tries to completely remove a PAD operator by using hardware padding.
1732 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1733 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1734 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1735 if both operations can be run on the NPU.
1736 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1737 """
1738 if (
1739 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001740 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001741 and op.run_on_npu
1742 and op.attrs["padding"] == Padding.VALID
1743 ):
1744 pad_op = op.ifm.ops[0]
1745 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1746 return op
1747 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1748 return op
1749 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1750 k = op.kernel
1751 k_w, k_h = k.dilated_wh()
1752
1753 # Check if the PAD operator can be replaced by hardware padding
1754 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1755 # Too much padding, it would require hardware padding to actually insert zeros
1756 return op
1757 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1758 return op
1759
1760 if op.type.is_avgpool_op():
1761 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1762 for pad, k_size in (
1763 (left, k_w),
1764 (right, k_w),
1765 (top, k_h),
1766 (bottom, k_h),
1767 ):
1768 if pad not in (0, k_size // 2):
1769 return op
1770 # Average pool is converted to depthwise, because NPU average pool + same padding
1771 # has a special implementation that is different from PAD followed by average pool with
1772 # valid padding.
1773 k_w, k_h = op.kernel.width, op.kernel.height
1774 ifm = op.ifm
1775 # Remember other inputs
1776 other_inputs = op.inputs[1:]
1777 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1778 quantization = QuantizationParameters(0.0, 255.0)
1779 quantization.scale_f32 = 1.0 / (k_w * k_h)
1780 quantization.zero_point = 0
1781 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1782 weights = np.full(shape, 1)
1783
1784 weight_tens = create_const_tensor(
1785 op.name + "_weights",
1786 shape,
1787 op.ifm.dtype,
1788 weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001789 purpose=TensorPurpose.Weights,
1790 quantization=quantization,
1791 )
James Peet7519d502021-07-19 16:47:58 +01001792 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001793 op.type = Op.DepthwiseConv2DBias
1794 op.inputs = []
1795 op.add_input_tensor(ifm)
1796 op.add_input_tensor(weight_tens)
Tim Hall5ff4cd12023-05-16 22:39:14 +01001797
1798 if op.ifm.dtype == DataType.uint8:
1799 op.rounding_mode = RoundingMode.HalfUp
1800
1801 # Add bias tensor, all biases set to 0
1802 op.inputs.append(None)
1803 fixup_bias_tensors(op, arch, nng, DataType.int32)
1804
1805 else:
1806 op.rounding_mode = RoundingMode.AwayZero
1807
1808 # The DepthwiseConv needs to be performed with the IFM zero point set appropriately so that the correct
1809 # pad values are used. However, in order to use the rounding away from zero mode the zero point needs to
1810 # have been removed so that the zero point is at zero. This is done by adding a kernel sized amount of
1811 # the zero point as a bias. The datatype of the bias needs to be set to int32, even for an int16 IFM,
1812 # because this will cause full precision scaling to be used (see weight compression). Finally, the OFM
1813 # zero point will need forcing to zero (as it has already been removed)
1814 nr_biases = op.inputs[1].shape[-1]
1815 bias_values = [op.ifm.quantization.zero_point * k_h * k_w] * nr_biases
1816 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
1817 op.add_input_tensor(bias_tensor)
1818
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001819 # Add other inputs
1820 op.inputs.extend(other_inputs)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001821
1822 # Bypass the PAD operator
1823 op.set_input_tensor(pad_op.ifm, 0)
1824 # Adjust the padding attributes of the convolution operator
1825 op.attrs["padding"] = Padding.EXPLICIT
1826 op.attrs["explicit_padding"] = (top, left, bottom, right)
1827 op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +00001828 DebugDatabase.add_optimised(op, op)
1829
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001830 return op
1831
1832
1833def convert_pad(op: Operation, arch, nng):
1834 """
1835 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1836 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1837 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1838 """
1839 if op.type != Op.Pad or not op.run_on_npu:
1840 return op
1841 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1842
1843 ifm = op.ifm
1844 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001845 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001846 ofm = op.ofm
1847 assert ofm is not None
1848 ofm.ops = []
1849 ofm_shape = op.ofm_shapes[0]
1850
1851 # Average pool op that copies IFM to the right place inside the OFM
1852 shp0 = Shape4D(0, 0, 0, 0)
1853 shp_top = shp0.with_height(top)
1854 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1855 avgpool_op.activation = op.activation
1856 quant = ofm.quantization
1857 pad_value = quant.zero_point
1858 # Add operations that fill the borders of the OFM
1859 if top > 0:
1860 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1861 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001862 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001863 )
1864 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1865 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1866 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1867 if bottom > 0:
1868 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1869 zero_tens = create_const_tensor(
1870 op.name + "_bottom",
1871 shape.as_list(),
1872 ofm.dtype,
1873 shape.elements() * [pad_value],
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001874 quantization=quant,
1875 )
1876 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1877 create_avg_pool_for_concat(
1878 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1879 )
1880 if left > 0:
1881 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1882 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001883 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001884 )
1885 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1886 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1887 if right > 0:
1888 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1889 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001890 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001891 )
1892 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1893 create_avg_pool_for_concat(
1894 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1895 )
1896
1897 op.type = Op.ConcatTFLite
1898 return avgpool_op
1899
1900
Raul Farkas66207142023-05-25 11:15:20 +01001901def fixup_bias_tensors(op: Operation, arch, nng, dtype=None) -> Operation:
1902 """Fixup ops that require a bias and don't have one by adding a bias tensor filled with zeros."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001903 if op.type.needs_bias() and op.bias is None:
1904 # Op has no bias, add bias tensor filled with zeros
1905 nr_biases = op.inputs[1].shape[-1]
1906 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001907 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1908 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1909 # For int16 the selected bias DataType will have an impact on the scaling
1910 # used when encoding the scales and biases later. The default mode will match the
1911 # reference with reduced scaling for int64 bias.
1912 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1913 # is used to emulate average pool int32 bias should be selected for full precision
1914 # int16 scaling.
1915 if dtype is None:
1916 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1917 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Raul Farkas3e7157b2023-05-09 09:09:17 +01001918 bias_index = op.type.info.indices.biases[0]
1919 if bias_index < len(op.inputs):
1920 op.set_input_tensor(bias_tensor, bias_index)
1921 else:
1922 op.add_input_tensor(bias_tensor)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001923
1924 return op
1925
1926
wilisa0146c94772023-02-08 09:56:14 +00001927def detect_asymmetric_weights(op):
1928 # Check all ops (cpu and npu)
1929 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
1930 if op.ifm.dtype in (DataType.int8, DataType.int16):
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001931 if not np.all(op.weights.quantization.zero_point == 0):
wilisa0146c94772023-02-08 09:56:14 +00001932 print(f"Warning: Op {op.type} '{op.name}' has asymmetric weights.", end=" ")
1933 return True
1934 return False
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001935
wilisa0146c94772023-02-08 09:56:14 +00001936
Raul Farkas66207142023-05-25 11:15:20 +01001937def fixup_asymmetric_weights(op: Operation, arch, nng) -> Operation:
wilisa0146c94772023-02-08 09:56:14 +00001938 if detect_asymmetric_weights(op):
1939 if op.run_on_npu:
1940 print("Zero points have been adjusted.")
1941 op.weights.quantization.zero_point *= 0
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001942 return op
1943
1944
wilisa0146c94772023-02-08 09:56:14 +00001945def check_asymmetric_weights(op, arch, nng):
1946 # This function can modify the run_on_npu flag which causes an operator to be placed on the CPU. It is usually only
1947 # set by the supported operator checks. Therefore, it should be run immediately after those checks to avoid the
1948 # possibility of other graph optimiser functions modify the operator (that is later run on the CPU)
1949 if detect_asymmetric_weights(op):
1950 if op.run_on_npu:
1951 print("To run the operator on Ethos-U use the option --force-symmetric-int-weights")
1952 op.run_on_npu = False
1953 return op
1954
1955
1956def fixup_or_check_asymmetric_weights(force_symmetric_int_weights):
1957 if force_symmetric_int_weights:
1958 return fixup_asymmetric_weights
1959 else:
1960 return check_asymmetric_weights
1961
1962
Rickard Bolina68b82a2023-04-20 15:12:28 +00001963def convert_mean_to_depthwise_conv(op, arch, nng):
Alexander Hansson90c34b52023-05-31 15:03:03 +00001964 """
1965 When h x w <= 4096          When h x w > 4096 there is a need to split into several ops.
1966                             Do this by splitting up h and changing the read_offset/shape.
1967                             Below is an example where ifm is 1x190x64x1
1968           MEAN                                      MEAN
1969             |                   |---------------------|---------------------|
1970   DepthwiseConv2DBias 1_DepthwiseConv2DBias 2_DepthwiseConv2DBias 3_DepthwiseConv2DBias
1971             |                   |                     |                     |
1972            MUL                  |---------ADD---------|                     |
1973                                            |                                |
1974                                            |--------------ADD--------------|
1975                                                            |
1976                                                           MUL
1977 1_DepthwiseConv2DBias: read_offset [0, 0, 0, 0]> read_shape [1, 64, 64, 1]>
1978 2_DepthwiseConv2DBias: read_offset [0, 64, 0, 0]> read_shape [1, 64, 64, 1]>
1979 3_DepthwiseConv2DBias: read_offset [0, 128, 0, 0]> read_shape [1, 62, 64, 1]>
1980 """
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001981 if op.type == Op.Mean and op.run_on_npu:
Alexander Hansson90c34b52023-05-31 15:03:03 +00001982 max_kernel_size = 4096
1983 max_height = 64
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001984 inp, axis = op.inputs
1985 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001986 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001987 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001988 dims_ofm = len(ofm_shape)
Alexander Hansson90c34b52023-05-31 15:03:03 +00001989 ofmq = op.ofm.quantization
1990 ifmq = op.ifm.quantization
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001991
1992 # Height and width axes have different index depending on dimensions
1993 if axis.shape == [] or axis.shape[0] == 1: # single axis
1994 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
Alexander Hansson90c34b52023-05-31 15:03:03 +00001995 # If dims is 4, axis 1 refers to h-dimension
1996 if dims == 4:
1997 reduce_h, reduce_w = (True, False) if axis == 1 else (False, True)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001998 else:
Alexander Hansson90c34b52023-05-31 15:03:03 +00001999 reduce_h, reduce_w = (True, False) if axis == 0 else (False, True)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002000 else: # multiple axes
2001 axis = sorted(axis.values)
Alexander Hansson90c34b52023-05-31 15:03:03 +00002002 reduce_h, reduce_w = (True, True)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002003
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002004 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01002005 def extend_dims(dim, in_shape):
2006 if dim < 4:
2007 in_shape = [1] + in_shape
2008 if dim == 2:
2009 in_shape += [1]
2010 return in_shape
2011
2012 if dims < 4 or dims_ofm < 4:
2013 # Fix the ofm dimension when keep_dims is false
2014 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
2015 if isinstance(axis, int) and dims_ofm + 1 == dims:
2016 ofm_shape.insert(axis, 1)
2017 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
2018 for i in axis:
2019 ofm_shape.insert(i, 1)
2020 shape = extend_dims(dims, shape)
2021 dims_ofm = len(ofm_shape)
2022 ofm_shape = extend_dims(dims_ofm, ofm_shape)
2023 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002024
Alexander Hansson90c34b52023-05-31 15:03:03 +00002025 # Compute kernel sizes for our convolutions
2026 h = shape[1] if reduce_h else 1
2027 w = shape[2] if reduce_w else 1
2028 num_elements_in_axis = h * w
2029
2030 # If one convolution is enough, but height is greater than max kernel height
2031 # reshape from HxW to 1x(HxW)
2032 # This can only be done if the mean is computed over both H and W
2033 if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_h and reduce_w:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002034 shape = [shape[0], 1, h * w, shape[3]]
2035 op.ifm_shapes[0] = Shape4D(shape)
Alexander Hansson90c34b52023-05-31 15:03:03 +00002036 op.ifm.shape = shape
2037 w = h * w
2038 h = 1
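# e.g. (hypothetical) a 1x80x40x1 IFM reduced over both H and W has
# 3200 <= 4096 elements but h = 80 > 64, so it is read as 1x1x3200x1 and
# handled by a single conv with a 1x3200 kernel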
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002039
Alexander Hansson90c34b52023-05-31 15:03:03 +00002040 intermediate_op = None
2041 height_per_conv = min(max_kernel_size // w, h)
2042 height_per_conv = min(height_per_conv, max_height)
2043 num_convs = math.ceil(h / height_per_conv)
2044 convs = list()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002045
Alexander Hansson90c34b52023-05-31 15:03:03 +00002046 for i in range(num_convs):
2047 is_last_op = i == (num_convs - 1)
2048
2049 intermediate_op = op.clone(f"{op.name}_conv_{i}")
2050
2051 intermediate_op.type = Op.DepthwiseConv2DBias
2052
2053 # Set necessary depthwise attributes
2054 intermediate_op.attrs.update(
2055 {
2056 "padding": Padding.VALID,
2057 "stride_h": 1,
2058 "stride_w": 1,
2059 "strides": (1, 1, 1, 1),
2060 "depth_multiplier": 1,
2061 "channel_multiplier": 1,
2062 "dilation_h_factor": 1,
2063 "dilation_w_factor": 1,
2064 "dilation": (1, 1, 1, 1),
2065 }
2066 )
2067
2068 b, _, _, c = shape
2069
2070 intermediate_tensor = op.ofm.clone(suffix=f"_conv_sum_{i}", set_unique=True)
2071 intermediate_tensor.dtype = DataType.int32
2072 intermediate_op.set_output_tensor(intermediate_tensor)
2073
2074 # as we have several convs, scaling/rounding must be done after the sum has been calculated
2075 intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
2076
2077 # compute height for the kernel
2078 if is_last_op and h % height_per_conv != 0:
2079 weight_h = h % height_per_conv
2080 else:
2081 weight_h = height_per_conv
2082
2083 # compute ifm read offset and shape for the convolution
2084 read_shape_h = weight_h if reduce_h else shape[1]
2085 read_shape_w = w if reduce_w else shape[2]
2086
2087 intermediate_op.read_offsets[0] = Shape4D([0, i * height_per_conv, 0, 0])
2088 intermediate_op.read_shapes[0] = Shape4D(shape).with_hw(read_shape_h, read_shape_w)
2089
2090 weight_quant = QuantizationParameters(0, 255, scale_f32=1.0, zero_point=0)
2091 weight_shape = [weight_h, w, c, b]
2092 weight_tensor = create_const_tensor(
2093 f"{intermediate_op.name}_weights",
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002094 weight_shape,
Alexander Hansson90c34b52023-05-31 15:03:03 +00002095 DataType.uint8,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002096 np.ones(weight_shape),
Alexander Hansson90c34b52023-05-31 15:03:03 +00002097 TensorPurpose.Weights,
2098 quantization=weight_quant,
2099 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002100
Alexander Hansson90c34b52023-05-31 15:03:03 +00002101 weights_1D = np.ones(np.prod(weight_shape))
2102 weight_tensor.equivalence_id = create_equivalence_id(tuple(weights_1D))
2103 weight_tensor.value_id = weight_tensor.equivalence_id
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002104
Alexander Hansson90c34b52023-05-31 15:03:03 +00002105 intermediate_op.set_input_tensor(weight_tensor, 1)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002106
Alexander Hansson90c34b52023-05-31 15:03:03 +00002107 dtype = DataType.int64 if intermediate_op.ifm.dtype == DataType.int16 else DataType.int32
2108 bias_values = [0] * c
2109 bias = create_const_tensor(f"{intermediate_op.name}_bias", [c], dtype, bias_values)
2110 bias.equivalence_id = create_equivalence_id(tuple(bias_values))
2111 bias.value_id = bias.equivalence_id
2112 intermediate_op.inputs.append(bias)
2113 intermediate_op.set_ifm_ofm_shapes()
Johan Alfven7b3008a2023-04-13 18:54:47 +02002114
Alexander Hansson90c34b52023-05-31 15:03:03 +00002115 # We want to avoid reshaping the tensor directly, to not affect other ops
2116 # so we update the shape explicitly for this operation
2117 intermediate_op.ifm_shapes[0] = Shape4D(shape)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002118
Alexander Hansson90c34b52023-05-31 15:03:03 +00002119 convs.append(intermediate_op)
2120 DebugDatabase.add_optimised(op, intermediate_op)
2121
2122 # If we have more than one convolution,
2123 # we use add operations to accumulate the intermediate tensors
2124 if len(convs) > 1:
2125 prev_add_op = None
2126 idx = 0
2127
2128 while len(convs):
2129 intermediate_tensor = op.ofm.clone(suffix=f"_add_sum_{idx}", set_unique=True)
2130 intermediate_tensor.dtype = DataType.int32
2131
2132 one_scale_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
2133
2134 ifm = convs.pop().ofm
2135 if not prev_add_op:
2136 ifm2 = convs.pop().ofm
2137 else:
2138 ifm2 = prev_add_op.ofm
2139
2140 intermediate_op = create_add(f"{op.name}_add_{idx}", ifm, ifm2, one_scale_quant)
2141 intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
2142 intermediate_op.set_output_tensor(intermediate_tensor)
2143 intermediate_op.set_ifm_ofm_shapes()
2144
2145 prev_add_op = intermediate_op
2146 idx += 1
2147
2148 DebugDatabase.add_optimised(op, intermediate_op)
2149
2150 # Convert the original mean op to our final Mul operation,
2151 # which scales and divides by num_elements_in_axis
2152 op.type = Op.Mul
2153 op.name = f"{op.name}_mul"
2154 op.attrs = {}
2155 op.set_input_tensor(intermediate_op.ofm, 0)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002156
Johan Alfven7b3008a2023-04-13 18:54:47 +02002157 # The multiplier is calculated in the same way as in the reference,
2158 # clamping the shift value at the price of some precision loss.
Johan Alfven7b3008a2023-04-13 18:54:47 +02002159 output_multiplier, output_shift_vela = quantise_scale(np.double(ifmq.scale_f32) / np.double(ofmq.scale_f32))
2160
2161 # Convert to reference representation shift value
2162 output_shift = 31 - output_shift_vela
2163
2164 # Reference calculation
2165 # round_down_log2 same as 63 - CountLeadingZeros(num_elements_in_axis)
2166 shift = round_down_log2(num_elements_in_axis)
2167 shift = min(shift, 32)
2168 shift = min(shift, 31 + output_shift)
2169 output_multiplier = (output_multiplier << shift) // num_elements_in_axis
2170 output_shift = output_shift - shift
2171
2172 # Convert to vela representation shift
2173 output_shift_vela = 31 - output_shift
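# Worked example (hypothetical, equal ifm/ofm scales, mean over 64 elements):
# quantise_scale(1.0) gives multiplier 2**30 with vela shift 30, i.e.
# reference shift 1; round_down_log2(64) = 6, so the multiplier becomes
# (2**30 << 6) // 64 = 2**30, the reference shift 1 - 6 = -5, and the final
# vela shift 31 - (-5) = 36; x * 2**30 >> 36 == x / 64 as required.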
2174
2175 # For int32 scaling is not supported so instead multiply with the scale
2176 # intermediate * scale -> round and shift.
Alexander Hansson90c34b52023-05-31 15:03:03 +00002177 identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
Johan Alfven7b3008a2023-04-13 18:54:47 +02002178 scalar = create_const_tensor(
2179 op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [output_multiplier], quantization=identity_quant
2180 )
Alexander Hansson90c34b52023-05-31 15:03:03 +00002181 op.set_input_tensor(scalar, 1)
2182 op.set_ifm_ofm_shapes()
Johan Alfven7b3008a2023-04-13 18:54:47 +02002183
2184 # Reference using TFL rounding for the multiply
Alexander Hansson90c34b52023-05-31 15:03:03 +00002185 op.rounding_mode = RoundingMode.TFLite
Johan Alfven7b3008a2023-04-13 18:54:47 +02002186
2187 # Need to use explicit scaling to get the wanted shift
Alexander Hansson90c34b52023-05-31 15:03:03 +00002188 op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
2189 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002190 return op
2191
2192
Raul Farkas66207142023-05-25 11:15:20 +01002193def convert_ops_to_lut(op: Operation, arch, nng) -> Operation:
2194 """Convert Exp to 8bit or 16bit LUT to allow for support on NPU."""
Johan Alfvence502732023-04-24 13:35:40 +02002195 if op.type == Op.Exp:
2196 if op.ifm.dtype == DataType.int8:
2197 return create_lut_8bit_op(op, math.exp, "exp")
2198 elif op.ifm.dtype == DataType.int16:
2199 return create_lut_int16_op(op, math.exp, "exp")
2200 else:
2201 # Should already be caught in tflite supported ops
2202 assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}"
2203
Johan Alfven8e525ca2023-05-07 13:12:37 +02002204 if op.type == Op.Rsqrt:
2205 return create_lut_rsqrt_int8_op(op)
2206
Johan Alfvence502732023-04-24 13:35:40 +02002207 return op
2208
2209
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002210def optimise_quantize(op: Operation, arch, nng):
2211
2212 if op.type == Op.Quantize and op.run_on_npu:
2213
2214 ifm, ofm = op.get_ifm_ofm()
2215 input_values = ifm.values
2216
2217 # Guard clause - input not const or no values to quantize
2218 if ifm.ops[0].type != Op.Const or input_values is None:
2219 return op
2220
2221 # Singular val in numpy array, convert to indexable array
2222 if input_values.ndim == 0:
2223 input_values = np.array([input_values])
2224
Fredrik Svedberg11563172022-07-06 14:54:12 +02002225 # Requantize int8 to int8 or int16 to int16
2226 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002227
2228 # scale needs to use double precision to match TFLite reference kernel
2229 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
2230 effective_multiplier, effective_shift = quantise_scale(effective_scale)
2231
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002232 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002233 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002234 input_val = val - ifm.quantization.zero_point
2235
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002236 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
2237 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002238
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002239 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
2240 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002241
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002242 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
2243 ofm.values.shape = input_values.shape
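# e.g. (hypothetical int8 values) with ifm scale 0.2, ofm scale 0.1 and both
# zero points 0, effective_scale = 2.0 and a stored value of 5 requantizes
# to 10, saturated into [quant_min, quant_max]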
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002244
2245 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002246 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002247
2248 quantized_vals = []
2249 for val in input_values:
2250
2251 # Derive quantized value
2252 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002253 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
2254 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002255
2256 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002257 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
2258
2259 # Unsupported data type
2260 else:
2261 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002262
2263 # Make quantize op const and disconnect from parent node
2264
2265 # Remove reference of the current quant op from the parent tensor's consumer list
2266 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
2267
2268 # Clear any references to parent node
2269 op.inputs = []
2270
2271 # Convert this quantize op to const
2272 op.type = Op.Const
2273
2274 return op
2275
2276
Ayaan Masood4965fae2022-06-29 11:30:57 +01002277def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
2278 """Static optimisation for SHAPE operator output value known at compile time"""
2279
2280 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
2281
2282 if op.type == Op.Shape and op.run_on_npu:
2283
2284 ifm, ofm = op.get_ifm_ofm()
2285
2286 if len(ifm.shape) != ofm.shape[0]:
2287 return op
2288
2289 # Remove reference of the current shape op from the parent tensor's consumer list
2290 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
2291
2292 # Clear any references to parent node
2293 op.inputs = []
2294
2295 # Convert this SHAPE op to const
2296 op.type = Op.Const
2297
2298 # Add size calculation to shape output tensors
2299 ofm.values = np.array(ifm.shape)
2300
2301 return op
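# e.g. an IFM with shape [1, 224, 224, 3] turns the SHAPE op into a Const
# whose output values are np.array([1, 224, 224, 3])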
2302
2303
Raul Farkas66207142023-05-25 11:15:20 +01002304def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation:
2305 """Fixup Conv2DBias and DepthwiseConv2DBias to allow dilation greater than 2."""
Tim Hallea4ba662022-11-11 18:19:53 +00002306 assert op.run_on_npu
2307 if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
2308 dilation_w, dilation_h = op.get_kernel_dilation()
2309
2310 # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
2311 # kernel
2312 if dilation_w > 2 or dilation_h > 2:
2313 kernel_w, kernel_h = op.get_kernel_size()
2314 kernel_ic = op.weights.shape[-2]
2315 kernel_oc = op.weights.shape[-1]
2316
2317 # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
2318 # of 2. this allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
2319 # odd = 1, even = 2
2320 hw_dilation_h = 1 if (dilation_h & 1) else 2
2321 hw_dilation_w = 1 if (dilation_w & 1) else 2
2322
2323 scale_dilation_h = dilation_h // hw_dilation_h
2324 scale_dilation_w = dilation_w // hw_dilation_w
2325
2326 # create new empty kernel (HWIO format)
2327 new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
2328 new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1
2329
2330 new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
2331 new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)
2332
2333 # copy the original kernel values into the new sparse kernel
2334 for h in range(0, kernel_h):
2335 for w in range(0, kernel_w):
2336 new_h = h * scale_dilation_h
2337 new_w = w * scale_dilation_w
2338 new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]
2339
2340 # update the weight tensor with the new dilated kernel
2341 op.weights.shape = new_kernel_shape
2342 op.weights.values = new_kernel_values
2343
2344 # enable(=2) / disable(=1) hardware dilation
2345 op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1) # nhwc format
2346 op.attrs["dilation_h_factor"] = hw_dilation_h
2347 op.attrs["dilation_w_factor"] = hw_dilation_w
2348
2349 return op
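# Example: dilation 4 decomposes into hardware dilation 2 and scale dilation 2,
# so a 3-tap kernel becomes a sparse 5-tap kernel ((3 - 1) * 2 + 1) whose span
# under hardware dilation 2 matches the original 3 taps at dilation 4
# ((5 - 1) * 2 + 1 == (3 - 1) * 4 + 1 == 9).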
2350
2351
Tim Hall2180a172023-03-10 18:11:34 +00002352def fixup_reshape(op, arch, nng):
2353 def _get_explicit_shape(implicit_shape, total_size):
2354 # the explicit shape is a copy of the implicit shape but with the special -1 (remaining size) value converted to
2355 # the appropriate value
2356 if implicit_shape is None:
2357 return None
2358
2359 explicit_shape = list(implicit_shape)
2360 if -1 in explicit_shape:
2361 explicit_shape[explicit_shape.index(-1)] = int(total_size / abs(np.prod(implicit_shape)))
2362
2363 return explicit_shape
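# e.g. an implicit shape of [-1, 4] for a tensor with 12 elements becomes [3, 4]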
2364
2365 if op.type == Op.Reshape:
2366 ifm_tensor, _, ofm_tensor = op.get_ifm_ifm2_ofm()
2367 ifm_size = ifm_tensor.elements()
2368 ofm_shape = ofm_tensor.shape
2369
2370 new_shape_tensor_shape = op.inputs[1].values.flatten() if len(op.inputs) > 1 else None
2371 new_shape_tensor_shape = _get_explicit_shape(new_shape_tensor_shape, ifm_size)
2372
2373 new_shape_attribute = op.attrs.get("new_shape", None)
2374 new_shape_attribute = _get_explicit_shape(new_shape_attribute, ifm_size)
2375
2376 # if present the new shape tensor overrides the new_shape attribute
2377 if new_shape_tensor_shape is not None:
2378 # check tensor
2379 if not np.array_equal(new_shape_tensor_shape, ofm_shape):
2380 print(
2381 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new shape tensor"
2382 f" ({new_shape_tensor_shape}) that does not match output tensor shape {ofm_shape}. Will use output"
2383 f" tensor shape."
2384 )
2385 elif new_shape_attribute is not None:
2386 # check attribute
2387 if not np.array_equal(new_shape_attribute, ofm_shape):
2388 print(
2389 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new_shape attribute"
2390 f" ({new_shape_attribute}) that does not match output tensor shape {ofm_shape}. Will use output"
2391 f" tensor shape."
2392 )
2393 else:
2394 print(
2395 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' does not have a new shape tensor or a new_shape"
2396 f" attribute. Will use output tensor shape {ofm_shape}."
2397 )
2398
2399 # force new shape tensor to output shape
2400 new_shape_tensor = create_const_tensor(
2401 op.name + "_new_shape", [len(ofm_shape)], DataType.int32, np.array(ofm_shape, np.int32)
2402 )
2403 if len(op.inputs) > 1:
2404 op.set_input_tensor(new_shape_tensor, 1)
2405 else:
2406 op.add_input_tensor(new_shape_tensor)
2407
2408 # force new_shape attribute to output shape
2409 op.attrs["new_shape"] = ofm_shape
2410
2411 return op
2412
2413
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002414def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02002415 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002416 return op
2417
2418
wilisa0146c94772023-02-08 09:56:14 +00002419def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
    # Compile time static optimisations
    optimisation_list = [
        optimise_quantize,
        convert_shape_op_to_constant_tensor,
        fixup_or_check_asymmetric_weights(force_symmetric_int_weights),
    ]
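    # these run before the NPU support checks below, so e.g. Shape ops folded to
    # constant tensors here never reach supported_operator_check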

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            optimisation_list,
            rewrite_unsupported=False,
        )

    # Pre-processing step
    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            pre_process_list,
            rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

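    # rewrite_split_ops converts Split/Slice-type ops into SplitSliceRead ops, which are
    # later merged into their consumers or removed (see remove_SplitSliceRead below)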
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [rewrite_split_ops],
            [],
            rewrite_unsupported=False,
        )

    # Bypass or rewrite memory only operators
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [bypass_memory_only_ops],
            rewrite_unsupported=False,
        )

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_ops_to_lut,
        convert_mean_to_depthwise_conv,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_lstm,
        convert_softmax,
        convert_prelu,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_avg_pool_to_conv2d,
        fixup_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        convert_argmax_to_depthwise_conv_and_max_pool,
        fixup_resize,
        fixup_bias_tensors,
        fixup_asymmetric_weights,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
        fixup_dilation_gt2,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            op_rewrite_list,
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimisations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after the optimisations above have
    # been performed, since this function relies on the ifm/ofm_shapes they set up
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    # Make sure that const optimisations on subgraph outputs are handled correctly
    for sg in nng.subgraphs:
        for ofm in sg.output_tensors:
            if ofm.is_const and ofm.ops[0].type_changed:
                # Subgraph output cannot be const - insert a memory copy
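                # the copy is a no-op elementwise Add of the cloned const data and zero,
                # so the original output tensor becomes a non-const computed tensor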
                op = ofm.ops[0]
                ofm_clone = ofm.clone()
                ofm_clone.values = ofm.values
                ofm.values = None
                zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
                memcpy = create_add_nop(f"{ofm.name}_copy")
                memcpy.add_input_tensor(ofm_clone)
                memcpy.add_input_tensor(zero)
                memcpy.set_output_tensor(ofm)
                memcpy.set_ifm_ofm_shapes()
                op.set_output_tensor(ofm_clone)
                DebugDatabase.add_optimised(op, memcpy)

    return nng