# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
from __future__ import annotations

import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import create_avg_pool_for_concat
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .lstm import Lstm
from .lut import convert_to_lut
from .lut import create_lut_8bit_op
from .lut import create_lut_int16_op
from .lut import create_lut_rsqrt_int8_op
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .numeric_util import round_down_log2
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation import RoundingMode
from .operation_util import create_add
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_cast_op
from .operation_util import create_depthwise_maxpool
from .operation_util import create_memcpy
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype
from .utils import calc_resize_factor

passthrough_nodes = (Op.Identity,)


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = Shape4D([oe - os for oe, os in zip(offset_end, offset_start)])

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer(s),
        # or if an avgpool needs to be inserted
        if op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape) and all(
            consumer is not None and consumer.run_on_npu and consumer.type not in memory_only_ops
            for consumer in op.ofm.consumer_list
        ):
            # SplitSliceRead can be performed by tensor consumer(s)
            for cons_op in list(op.ofm.consumer_list):
                move_splitsliceread_to_consumer(op, cons_op)
        else:
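            # The consumer(s) cannot absorb the strided read, so keep it by inserting a NOP
            # average pool that performs the read (given by read_offsets/read_shapes) into the OFM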
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
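        # e.g. a 3x3 kernel with stride 1 gives xpad == ypad == 2, split as left/top == 1 and
        # right/bottom == 1; when the total padding is odd the extra element goes to the right/bottom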
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(
    padding_type, kernel_size, stride, input_shape, upscaling_factor_y, upscaling_factor_x
):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor_y, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor_x, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor_x) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor_y) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op: Operation, arch, nng) -> Operation:
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        stride_w = op.kernel.stride.x
        stride_h = op.kernel.stride.y
        if stride_w > 1 or stride_h > 1:
            # Transpose conv2d with upscaling
            op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
        DebugDatabase.add_optimised(op, op)

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    name = op.inputs[1].name + "_add"
    dtype = op.inputs[0].dtype
    shape = op.ofm_shapes[0].as_list()
    values = np.zeros(shape, dtype.as_numpy_type())
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = 1.0
    quantization.zero_point = 0
    op.inputs[1] = op.inputs[0]
    op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 0)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change ResizeNearestNeighbor to Depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype.type == BaseType.UnsignedInt:
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
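    # e.g. upscale_factor == 2 gives centre_coeff == 3, i.e. weight_values == [0, 0, 0, 1] and the
    # single coefficient selects value 'D' in the diagram above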

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm_dtype,
            np.array(weight_values).reshape(weight_shape),
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))
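    # e.g. upscale_factor == 8 gives n == 3: the loop below performs the first two x2 upscale steps
    # and the final x2 step is handled separately with a kernel size that depends on the upscale factor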

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # Keep 1x1 kernel and average pool, this applies both when
            # half-pixel-centers is True and False. Calculations are the
            # same in the reference.
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, scaled_op)

    return op


def convert_argmax_to_depthwise_conv_and_max_pool(op: Operation, arch, nng) -> Operation:
    """
    Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.

    Example:
    arr = [4,   [00000100,
           6, =  00000110,  # <-- This is the largest value, so we're expecting argmax(arr) = 1
           5]    00000101]

    Use 16-bit precision and shift all values 7 bits to the left:
    Shifted_arr = [0000001000000000,
                   0000001100000000,
                   0000001010000000]

    Add "c - index of channel" to each channel:
    Shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
                                    0000001100000001, (+1)
                                    0000001010000000] (+0)

    The index is reversed since ArgMax selects the lowest index if the maximum value is found at two indices. The index
    will act as a tie-breaker between channels with equal values and since we want the smallest channel index to be
    chosen we reverse the index before the maxpool and then subtract the index from the number of channels after the
    maxpool to get the correct index.

    Find the maximum value in the array:
    val = max(shifted_arr_plus_reverse_idx) = 0000001100000001

    Subtract the value from the number of channels:
    shifted_arr_plus_idx = (c-1) - val = 2 - 1 = 1

    Extract the 7 lowest bits using a LUT to cut off the 9 most significant bits:
    idx = LUT(val) = 0000000000000001 = 1
    """

    if op.type == Op.ArgMax:
        ifm, ofm = op.inputs[0], op.outputs[0]
        identity_quant = QuantizationParameters()
        identity_quant.zero_point = 0
        identity_quant.scale_f32 = 1.0
        # Add last dimension to ofm shape
        ofm.shape += [1]
        ofm.ops = []

        # Create 1x1 Depthwise convolution with 2**7 weights for each channel to convert precision to 16 bit and shift
        # all values 7 bits to the left
        # Set necessary depthwise attributes
        dw_op_attrs = {
            "padding": Padding.VALID,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
            "explicit_padding": None,
        }
        orig_name = op.name
        op.name = f"{orig_name}_depthwise_conv_SHL_7"
        op.type = Op.DepthwiseConv2DBias
        op.attrs.update(dw_op_attrs)
        n, h, w, c = full_shape(4, ifm.shape, 1)
        shape = [1, 1, 1, c]
        kernel = np.dstack([2**7] * c)
        op.inputs = []
        op.add_input_tensor(ifm)
        op.add_input_tensor(
            create_const_tensor(
                "weights",
                shape,
                DataType.uint8,
                np.array(kernel).reshape(shape),
                quantization=identity_quant,
            ),
        )
        # Let the bias for each channel be the "reverse" index of the channel it is in, i.e. (c - 1) - channel_idx
        reverse_idxs = list(reversed(range(c)))
        bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
        op.add_input_tensor(bias_tensor)

        intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
        intermediate_tens.quantization = ifm.quantization
        op.set_output_tensor(intermediate_tens)
        op.set_ifm_ofm_shapes()
        orig_ifm_shape = op.ifm_shapes[0]
        DebugDatabase.add_optimised(op, op)

        # To extract 7 least significant bits and swap reverse index back to real index using a LUT activation, we set
        # the base value to c-1 and slope to -128. The 16-bit LUT uses a table of 32-bit values where the top 16 bits
        # represent the slope and bottom 16 bits the base which are used to interpolate the activation value.
        slope = (-128 & 0xFFFF) << 16  # Top 16 bits of 32 bit LUT table value
        base = c - 1  # Bottom 16 bits of the LUT table value
        lut_tensor = create_const_tensor(
            "maxpool_LUT_extract_7_LSB",
            [1, 1, 1, 512],
            DataType.uint32,
            [slope + base] * 512,
            TensorPurpose.LUT,
        )

        # Split large feature maps into smaller chunks since the Depthwise Maxpool height dimension can overflow due to
        # flattening the ifm to (H*W)xCx1
        max_height = 2**16 // orig_ifm_shape.width
        num_full_height_ops = orig_ifm_shape.height // max_height
        last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
        op_heights = [max_height] * num_full_height_ops
        if last_op_height > 0:
            op_heights.append(last_op_height)

        # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
        # maximum allowed height, but that's handled by reading and writing the data in chunks
        maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
        maxpool_ofm.quantization = identity_quant

        for op_idx, op_height in enumerate(op_heights):
            maxpool_op = create_depthwise_maxpool(
                f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
            )
            maxpool_op.outputs = [maxpool_ofm]
            maxpool_ofm.ops.append(maxpool_op)
            maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
            maxpool_op.set_activation_lut(lut_tensor)

            # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
            maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
            maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
            maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
            maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
            DebugDatabase.add_optimised(op, maxpool_op)

        # Set final shape
        maxpool_ofm.set_all_shapes([1, h, w, 1])

        # Convert 16bit to 32bit or 64bit
        if ofm.dtype == DataType.int64:
            # If OFM dtype is int64 the result is converted by two cast ops (16bit to 32bit)
            #
            # A -> B -> C -> D (OFM)
            # |0001| |00010000| |0001|0000| |00010000|00000000|
            #  i16      i32      i16  i16       i32      i32
            #                                 <-------i64------->
            #
            # Memcpy is used to copy the content from B to C and from D to OFM
            # Memcpy will be turned into a nop or a DMA transfer if the memory regions differ.
            intermediate_32bit = Tensor([1, h, w, 1], DataType.int32, f"{orig_name}_32bit")
        else:
            intermediate_32bit = ofm

        op_cast = create_cast_op(f"{orig_name}_cast_to_32bit_1", maxpool_ofm, intermediate_32bit)
        DebugDatabase.add_optimised(op, op_cast)

        if ofm.dtype == DataType.int64:
            # Create int16 tensor with double shape to cover the intermediate_32bit result from the first cast
            intermediate_16bit_2x_size = Tensor([1, h, w, 2], DataType.int16, f"{orig_name}_16bit_2x_size")
            memcpy_op = create_memcpy(f"{orig_name}_memcpy_1", intermediate_32bit, intermediate_16bit_2x_size)
            DebugDatabase.add_optimised(op, memcpy_op)

            # Create int32 tensor with double ofm shape to be able to store a "int64" result
            intermediate_32bit_2x_size = Tensor([1, h, w, 2], DataType.int32, f"{orig_name}_32bit_2x_size")

            op_cast = create_cast_op(
                f"{orig_name}_cast_to_32bit_2", intermediate_16bit_2x_size, intermediate_32bit_2x_size
            )
            DebugDatabase.add_optimised(op, op_cast)

            memcpy_op = create_memcpy(f"{orig_name}_memcpy_2", intermediate_32bit_2x_size, ofm)
            DebugDatabase.add_optimised(op, memcpy_op)

    return op


def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
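                # The coefficients are scaled into the integer domain here; the weight quantisation
                # scale of 1.0 / 16 set further down undoes this scaling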
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Resize bilinear requires rounding away from zero
        dw_conv.rounding_mode = RoundingMode.AwayZero

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, dw_conv)

                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op: Operation, arch, nng) -> Operation:
    """Fixup resize ops to increase support for ResizeNearestNeighbor cases."""
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng) -> Operation:
    """Rewrite FullyConnected shape as 2D to allow it to run on NPU."""
    # If the operation already has a read shape, do not modify
    # the ifm shape, since that will already be correct
    if op.type == Op.FullyConnected and not op.read_shapes[0]:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
    """Convert batched FullyConnected op shape to allow for support on NPU."""
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
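            # e.g. a batch of 8 is reshaped from [8, 1, 1, depth] to [1, 2, 4, depth]; batch sizes
            # without an entry in batching_split fall back to [1, 1, n, depth]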

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, act_op)


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias and op.ifm_resampling_mode == resampling_mode.TRANSPOSE:
                # Transpose with upscale
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"],
                    kernel_size,
                    op.attrs["strides"],
                    input_shape,
                    output_shape.height // input_shape.height,
                    output_shape.width // input_shape.width,
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op: Operation, arch, nng) -> Operation:
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        if not weight_tensor.weight_transpose_depthwise:
            weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
            weight_tensor.weight_transpose_depthwise = True

    return op


def convert_avg_pool_to_conv2d(op: Operation, arch, nng) -> Operation:
    """Convert strided Average Pools with stride >= 4 to Conv2D."""
    if op.type != Op.AvgPool:
        return op

    stride_x, stride_y = op.get_kernel_stride()
    # For strides <= 3 no optimization is needed
    if stride_x <= 3:
        return op
    h, w = op.attrs["filter_height"], op.attrs["filter_width"]
    inputs = op.inputs[0]
    shape = inputs.shape

    # Set necessary conv2d attributes
    op.attrs.update(
        {
            "stride_h": stride_y,
            "stride_w": stride_x,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "strides": (1, stride_y, stride_x, 1),
            "dilation": (1, 1, 1, 1),
        }
    )

    # Change op type
    op.type = Op.Conv2DBias
    op.name += "_conv2d"

    op.rounding_mode = RoundingMode.AwayZero
    shape = [h, w, 1, op.ofm.shape[-1]]
    weights = np.full(shape, 1)
    quant = QuantizationParameters(scale_f32=1 / (h * w), zero_point=0)
    # Add unit weight tensor
    op.add_input_tensor(
        create_const_tensor(
            "weights",
            shape,
            inputs.dtype,
            weights,
            quantization=quant,
        ),
    )
    op.weights.values = np.reshape(op.inputs[1].values, shape)

    # Set IFM/OFM shapes after changing op type
    op.set_ifm_ofm_shapes()
    return op


def fixup_strided_conv(op: Operation, arch, nng):
    """Optimize or fixup strided Conv2DBias
    Optimization:
        Reduce, when possible, the Conv2DBias stride from N with 1 < N < 4 to 1
        by re-shaping both IFM and filter.

    Fixup:
        Introduce software support for Conv2DBias with stride_width > 4 by
        reducing it to 1, 2 or 3 (HW supported strides) when possible by
        re-shaping both IFM and filter.
    """
    if op.type != Op.Conv2DBias:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    # Do not optimize if op is not the first in the network and stride is
    # supported by the hardware
    if op.op_index != 0 and stride_x < 4:
        return op

    resize_factor, final_stride = calc_resize_factor(ifm_shape.width, stride_x)

    def calc_filter_padding(
        ifm_padding_type: Padding | None,
        ifm_current_padding_x: int,
        post_op_stride: int,
        opt_resize_factor: int,
        filter_width: int,
        ifm_width: int,
    ) -> tuple[int, int, int, int]:
        """Calculate zero padding to be added to the filter.

        Parameters
        ----------
        ifm_padding_type : Padding or None
            The padding type that is applied to the IFM.
        ifm_current_padding_x : int
            Padding amount that is added to the IFM before optimization.
        post_op_stride : int
            The final stride once optimization is performed.
        opt_resize_factor : int
            The factor by which the stride will be reduced.
            E.g. opt_resize_factor = 2 on a stride of 4 will produce
            a stride of 2 after the optimization
        filter_width : int
            Width of the filter before optimization.
        ifm_width : int
            Width of the IFM before optimization

        Returns
        -------
        padding : tuple[int, int, int, int]
            A tuple with the amount of padding on each side (top, left, bottom, right)
        """
        padding_size = 0
        padding = (0, 0, 0, 0)
        if ifm_padding_type and ifm_padding_type != Padding.VALID:
            # Compute padding size for the filter that guarantees that HW padding added to IFM matches
            # before and after the optimization is performed
            expected_filter_size = 0
            pre_opt_stride = post_op_stride * opt_resize_factor
            post_opt_ifm_width = ifm_width // opt_resize_factor
            # Compute the total expected filter size post optimization that ensures that the same HW padding
            # is added to IFM.
            # There are two ways of calculating required filter size depending on whether IFM width is divisible
            # by stride width or not. These approaches match the cases used to calculate HW padding in
            # needed_total_padding method.
            if ifm_width % pre_opt_stride == 0:
                expected_filter_size = ifm_current_padding_x + post_op_stride
            else:
                expected_filter_size = ifm_current_padding_x + (post_opt_ifm_width % post_op_stride)
            # Compute padding size from expected filter size
            padding_size = expected_filter_size * opt_resize_factor - filter_width

            if ifm_current_padding_x == 0:
                # If no HW padding is added to IFM, divide filter padding between left and right following
                # the same strategy as the reference.
                padding_left = padding_size // 2
            else:
                # If HW padding is added to IFM, split padding for the filter so that left padding and right padding
                # are proportional to left and right HW padding.
                left_hw_padding = ifm_current_padding_x // 2
                # Compute filter padding
                padding_left = padding_size // ifm_current_padding_x * left_hw_padding
            padding = (0, padding_left, 0, padding_size - padding_left)

        # Check if filter width is divisible by the stride width (required for optimization)
        # If filter width is not divisible by stride width and no HW padding is added to IFM, compute
        # filter padding required for the filter width to be divisible by the stride width and apply it as right
        # padding.
        if filter_width % opt_resize_factor != 0 and (padding_size == 0 or ifm_current_padding_x == 0):
            padding_size = opt_resize_factor - (filter_width % opt_resize_factor)
            # Add padding zeros to the right
            padding = (0, 0, 0, padding_size)

        return padding

    # Compute the depth of the IFM once the strided Conv2D is optimised
    post_opt_ifm_depth = ifm_shape.depth * resize_factor

    if stride_x > 1 and (post_opt_ifm_depth <= 8 or stride_x > 3) and resize_factor != 1 and weight_tensor is not None:
        k_w, _ = op.get_kernel_size()
        weight_shape = weight_tensor.shape

        padding_type = op.attrs.get("padding", None)
        if padding_type in (None, Padding.EXPLICIT, Padding.TILE):
            return op
        # Compute current padding as if IFM padding is SAME
        curr_padding_x = needed_total_padding(ifm_shape.width, stride_x, k_w)
        # Compute the padding needed on the filter for the optimisation
        _, left_filter_padding, _, right_filter_padding = calc_filter_padding(
            padding_type, curr_padding_x, final_stride, resize_factor, k_w, ifm_shape.width
        )
        total_horizontal_padding = left_filter_padding + right_filter_padding
        # If IFM padding is enabled, check if pre-opt and post-opt padding is
        # the same while taking into consideration the extra filter padding.
        if padding_type == Padding.SAME:
            optimised_padding_x = needed_total_padding(
                ifm_shape.width // resize_factor, final_stride, (k_w + 1 + total_horizontal_padding) // resize_factor
            )
            if curr_padding_x != optimised_padding_x:
                # Horizontal padding would become different after optimisation; this would not work
                return op

        # Resize IFM
        op.ifm_shapes[0] = Shape4D(
            [ifm_shape.batch, ifm_shape.height, ifm_shape.width // resize_factor, ifm_shape.depth * resize_factor]
        )
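        # The convolution is effectively folded along the width axis: the IFM is read with its width
        # divided by resize_factor and its depth multiplied by it, and the (padded) filter below is
        # reshaped to match, so the op can run with the reduced final_stride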

        # Compute a list of 0 padding for each dimension of the filter
        filter_dimension_padding = [(0, 0) for _ in weight_tensor.shape]
        # Update padding for filter width with computed padding
        filter_dimension_padding[1] = (left_filter_padding, right_filter_padding)
        # Add padding to the filter
        zero_point = weight_tensor.quantization.zero_point
        padding_constant = zero_point if np.isscalar(zero_point) else 0
        padded_filter_tensor = np.pad(weight_tensor.values, filter_dimension_padding, constant_values=padding_constant)
        weight_shape[1] = padded_filter_tensor.shape[1]
        weight_tensor.values = padded_filter_tensor
        # Change weight shape based on stride_x
        weight_shape[1] //= resize_factor
        weight_shape[2] *= resize_factor

        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = final_stride
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op: Operation, arch, nng) -> Operation:
    """Convert 1x1 Conv2D ops that behave like FullyConnected to FullyConnected, since they don't need any weight
    buffering.
    """
    # Conv 1x1 can be equivalent to Fully Connected.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op: Operation, arch, nng) -> Operation:
    """Fixup Relu with different IFM and OFM to allow fusing by adding its own primary op."""
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_lstm(op: Operation, arch, nng) -> Operation:
1229 """Convert LSTM op into its basic operations to allow for support on NPU."""
Fredrik Svedberg0ac08042023-04-11 22:35:04 +02001230 if op.type == Op.UnidirectionalSequenceLstm:
1231 lstm = Lstm(op)
1232 op = lstm.get_graph()
1233 return op
1234
1235
Raul Farkas66207142023-05-25 11:15:20 +01001236def convert_softmax(op: Operation, arch, nng) -> Operation:
1237 """Convert Softmax op into its basic operations to allow for support on NPU."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001238 if op.type == Op.Softmax and op.run_on_npu:
1239 softmax = SoftMax(op)
1240 op = softmax.get_graph()
1241 return op
1242
1243
Raul Farkas66207142023-05-25 11:15:20 +01001244def convert_prelu(op: Operation, arch, nng) -> Operation:
1245 """Convert PReLU op to other ops based on alpha values to allow for support on NPU."""
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001246 if op.type == Op.Prelu:
1247 ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
1248 if None in (ifm, alpha, ofm):
1249 return op
1250
Fredrik Svedberg66591652022-08-29 10:51:27 +02001251 if alpha.values is not None:
1252 # If const alpha check for possible optimisations
1253 alpha_zp = alpha.quantization.zero_point
1254 alpha_scale = alpha.quantization.scale_f32
1255 # If all alpha values are the same the PReLU can be converted to LeakyRelu
Rickard Bolin5fdcf172022-12-19 12:56:17 +00001256 alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
1257 alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
Fredrik Svedberg66591652022-08-29 10:51:27 +02001258 if alpha_min == alpha_max:
1259 # or even a Relu
1260 if alpha_min == 0:
1261 new_op = Op.Relu
1262 else:
1263 new_op = Op.LeakyRelu
1264 op.attrs["alpha"] = alpha_min
1265 # set up alpha_scaling for a bit-exact result
1266 ifm_scale = ifm.quantization.scale_f32
1267 ofm_scale = ofm.quantization.scale_f32
1268 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
1269 op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
1270 # Change op type
1271 op.type = new_op
1272 op.name = op.name.replace("Prelu", new_op.name)
1273 del op.inputs[1] # Remove alpha tensor
1274 return op
1275 elif alpha_max < 1:
1276 # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
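# Why Max reproduces PReLU here: PReLU(x) = x for x >= 0 and alpha * x for x < 0. With every
# alpha value below 1, alpha * x <= x when x >= 0 and alpha * x >= x when x < 0, so
# max(alpha * x, x) equals PReLU(x). E.g. alpha = 0.25, x = -8 gives max(-2, -8) = -2.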
1277 # Multiply with alpha tensor
1278 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
1279 mul_alpha.add_input_tensor(ifm)
1280 mul_alpha.add_input_tensor(alpha)
1281 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1282 mul_alpha.set_output_tensor(fm_alpha)
1283 mul_alpha.set_ifm_ofm_shapes()
1284 DebugDatabase.add_optimised(op, mul_alpha)
1285 if check_quantized_tens_scaling_equal(ifm, ofm):
1286 # No scaling is needed
1287 fm_id = ifm
1288 else:
1289 # Add multiplication with identity
1290 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1291 mul_identity.add_input_tensor(ifm)
1292 # Create const tensor containing identity as scalar
1293 quantization = ifm.quantization.clone()
1294 quantization.scale_f32 = np.float32(1)
1295 quantization.zero_point = 0
1296 one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
1297 mul_identity.add_input_tensor(one)
1298 # Make sure that fm_id is allocated to a different address than fm_alpha
1299 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1300 mul_identity.set_output_tensor(fm_id)
1301 mul_identity.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +00001302 DebugDatabase.add_optimised(op, mul_identity)
Fredrik Svedberg66591652022-08-29 10:51:27 +02001303
1304 # Combine scaled and alpha multiplied values
1305 max_op = Operation(Op.Maximum, op.name + "_max")
1306 max_op.add_input_tensor(fm_alpha)
1307 max_op.add_input_tensor(fm_id)
1308 max_op.set_output_tensor(ofm)
1309 max_op.set_ifm_ofm_shapes()
1310
1311 DebugDatabase.add_optimised(op, max_op)
1312 ifm.consumer_list.remove(op)
1313 return max_op
1314
1315 # Catch-all PReLU conversion for the cases that could not be optimised above
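# The generic lowering implements PReLU(x) = Relu(x) + alpha * min(x, 0): the Minimum selects
# the negative part, the Mul scales it by alpha, the Relu selects (and rescales) the positive
# part, and the final Add with unit explicit scaling combines the two halves.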
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001316 no_scale_quant = ifm.quantization.clone()
1317 no_scale_quant.scale_f32 = None
1318 no_scale_quant.zero_point = 0
Fredrik Svedberg66591652022-08-29 10:51:27 +02001319 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001320
1321 # Select values < 0
1322 min_op = Operation(Op.Minimum, op.name + "_min")
1323 min_op.add_input_tensor(ifm)
1324 min_op.add_input_tensor(zero)
1325 fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
1326 min_op.set_output_tensor(fm_negative)
1327 min_op.set_ifm_ofm_shapes()
1328 DebugDatabase.add_optimised(op, min_op)
1329
1330 # and multiply with alpha tensor
1331 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
1332 mul_alpha.add_input_tensor(fm_negative)
1333 mul_alpha.add_input_tensor(alpha)
1334 fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
1335 mul_alpha.set_output_tensor(fm_alpha)
1336 mul_alpha.set_ifm_ofm_shapes()
1337 DebugDatabase.add_optimised(op, mul_alpha)
1338
1339 # Select (and scale) values > 0
1340 relu_op = Operation(Op.Relu, op.name + "_relu")
1341 relu_op.add_input_tensor(ifm)
1342 fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1343 relu_op.set_output_tensor(fm_scaled)
1344 relu_op.set_ifm_ofm_shapes()
1345 DebugDatabase.add_optimised(op, relu_op)
1346
1347 # Add scaled and alpha multiplied values (without scaling)
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001348 add_op = Operation(Op.Add, op.name + "_add")
1349 add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001350 add_op.add_input_tensor(fm_alpha)
1351 add_op.add_input_tensor(fm_scaled)
1352 add_op.set_output_tensor(ofm)
1353 add_op.set_ifm_ofm_shapes()
1354
1355 DebugDatabase.add_optimised(op, add_op)
1356 ifm.consumer_list.remove(op)
1357 op = add_op
1358
1359 return op
1360
1361
Raul Farkas66207142023-05-25 11:15:20 +01001362def convert_mul_max_to_abs_or_lrelu(op: Operation, arch, nng) -> Operation:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001363 r"""Whenever there is a subgraph with this topology:
1364
Jonas Ohlssond8575072022-03-30 10:30:25 +02001365 Input X For X = -1 or X > 0
1366 | \ / This subgraph can be replaced with either
1367 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
1368 | /
1369 Max
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001370 """
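# Concrete instances (hypothetical subgraphs): y = Maximum(x, Mul(x, 0.1)) is rewritten to
# LeakyRelu(x, alpha=0.1) and y = Maximum(x, Mul(x, -1)) to Abs(x), provided the Mul has no
# other consumers or fused activation and the IFM/OFM quantizations match (checked below).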
1371
1372 if op.type == Op.Maximum:
1373 # finds the Mul input(s) to the Max
1374 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
1375 if len(muls) == 1:
1376 mul = muls[0].ops[0]
1377 elif len(muls) == 2:
1378 # In the case both inputs are Muls, find the one with the same input as the Max
Fredrik Svedberg66591652022-08-29 10:51:27 +02001379 mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
1380 if len(mul_ifms):
1381 mul = mul_ifms[0].ops[0]
1382 else:
1383 # Not using same input
1384 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001385 else:
1386 # No Mul inputs
1387 return op
1388
1389 # make sure the Mul doesn't have any other consumers
1390 mul_ofm = mul.outputs[0]
1391 if len(mul_ofm.consumers()) != 1:
1392 return op
1393 # make sure the Mul doesn't have a fused activation function
1394 if mul.activation:
1395 return op
1396 ifm, ofm = op.get_ifm_ofm()
1397 if ifm is None or ofm is None:
1398 return op
1399
1400 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1401 return op
1402 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
1403 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
1404 return op
1405
1406 # finds the branched input that goes to both the Max and the Mul
1407 shared = set(op.inputs) & set(mul.inputs)
1408 if len(shared) == 1:
1409 shared_in = shared.pop()
1410 # find the constant scalar input to the Mul
1411 const_tens = (set(mul.inputs) - {shared_in}).pop()
1412 # check that it is a scalar
1413 if const_tens.shape != []:
1414 return op
1415 const = const_tens.ops[0]
1416 # check that it is a constant
1417 if const.type != Op.Const:
1418 return op
1419 # Remove the Mul from the shared input's consumers
1420 shared_in.consumer_list.remove(mul)
1421 else:
1422 return op
1423
1424 val = const.outputs[0].values
1425 if val >= 0:
1426 new_op = Op.LeakyRelu
1427 op.attrs["alpha"] = val
1428 # to produce bit exact results, the alpha is not enough;
1429 # save additional scaling info in attr "alpha_scaling", to be used as input
1430 # to the LUT construction
James Peet7519d502021-07-19 16:47:58 +01001431 alpha_scalar = const_tens.values - const_tens.quantization.zero_point
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001432 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
1433 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
1434 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
1435 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
1436 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
1437 elif val == -1:
1438 new_op = Op.Abs
1439 else:
1440 return op
1441
1442 op.type = new_op
1443 op.name = op.name.replace("Maximum", new_op.name)
1444 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
1445 op.inputs = [shared_in]
1446 op.set_ifm_ofm_shapes()
1447
1448 # Record optimisation in debug database
1449 DebugDatabase.add_optimised(op, op)
1450
1451 return op
1452
1453
Raul Farkas66207142023-05-25 11:15:20 +01001454def convert_hardswish_to_lut(op: Operation, arch, nng) -> Operation:
1455 """Convert HardSwish to LUT to allow for support on NPU."""
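# Reference function being tabulated: hard_swish(x) = x * relu6(x + 3) / 6. The code below
# evaluates a fixed-point approximation of this for each of the 256 possible quantized input
# values so the op can execute as a single LUT lookup on the NPU.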
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001456 if op.type == Op.HardSwish:
1457 ifm, ofm = op.get_ifm_ofm()
1458 # Generate the LUT
1459 ifm_scale = np.double(ifm.quantization.scale_f32)
1460 ofm_scale = np.double(ofm.quantization.scale_f32)
1461 zp_in = ifm.quantization.zero_point
1462 zp_out = ofm.quantization.zero_point
1463 ifm_scale_hires = (1 / 128) * ifm_scale
1464 relu_multiplier = np.double(3 / 32768)
1465 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
1466 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
1467 # Use 16bit scale
1468 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
1469 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
1470
1471 values = []
1472 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1473 quantized_min = min(ix)
1474 quantized_max = max(ix)
1475 for x in ix:
1476 input_value = x - zp_in
1477 input_value_hires = input_value * 128
1478 # Compute the input value on essentially the output scale, not shifted yet
1479 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
1480 # Compute the "relu-ish multiplier". This matches the code in the TensorFlow Lite Micro kernel
1481 relu_value = np.int16(input_value_hires)
1482 if relu_shift < 31:
1483 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
1484
1485 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
1486
1487 if relu_shift < 31:
1488 relu_value = fp_math.shift_left16(relu_value, 1)
1489
1490 if relu_shift > 31:
1491 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
1492
1493 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
1494 # Now convert that to a 16bit fixedpoint value in [0, 1]
1495 relu_value = (relu_value + (1 << 15)) >> 1
1496 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
1497 shift = 31 - out_shift
1498 shift = -shift if shift < 0 else 0
1499 # Finally apply the output shift
1500 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
1501 lut_result = min(quantized_max, max(quantized_min, lut_result))
1502 values.append(lut_result)
1503 return convert_to_lut(op, values, "hardswish")
1504 return op
1505
1506
1507def convert_lrelu_to_mul_max(op, arch):
1508 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
1509 # (the opposite of convert_mul_max_to_abs_or_lrelu)
1510 ifm, ofm = op.get_ifm_ofm()
1511 if ifm is None or ofm is None:
1512 return op
1513
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001514 alpha = np.float32(op.attrs["alpha"])
1515 use_mul_max = 0 < alpha < 1
Fredrik Svedberg36424312022-09-16 09:39:26 +02001516 is_converted_prelu = "alpha_scaling" in op.attrs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001517 if use_mul_max:
1518 mul_ifm = ifm
1519 new_op = Op.Maximum
1520 else:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001521 # Need to use a different approach for alpha <= 0 or alpha >= 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001522 no_scale_quant = ifm.quantization.clone()
1523 no_scale_quant.scale_f32 = None
1524 no_scale_quant.zero_point = 0
1525 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
1526
1527 # Select values < 0
1528 min_op = Operation(Op.Minimum, op.name + "_min")
1529 min_op.add_input_tensor(ifm)
1530 min_op.add_input_tensor(zero)
1531 mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001532 if alpha < 0 and not is_converted_prelu:
1533 # For negative alpha that is not from a converted PReLU we need to use
1534 # int32 Mul below to perform the (negative) alpha scaling
1535 mul_ifm.dtype = DataType.int32
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001536 min_op.set_output_tensor(mul_ifm)
1537 min_op.set_ifm_ofm_shapes()
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001538 new_op = Op.Add
1539 op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001540 DebugDatabase.add_optimised(op, min_op)
1541
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001542 # Add multiplication with alpha
1543 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001544 mul_alpha.add_input_tensor(mul_ifm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001545 # Create const tensor containing alpha as scalar
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001546 quantization = ifm.quantization.clone()
1547 quantization.min = 0
1548 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
1549 quantization.zero_point = 0
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001550 alpha_dtype = mul_ifm.dtype
Fredrik Svedberg36424312022-09-16 09:39:26 +02001551 if is_converted_prelu:
1552 # The LeakyRelu was the result from convert_prelu and the scaling is provided
Fredrik Svedberg66591652022-08-29 10:51:27 +02001553 scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001554 mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001555 elif alpha == 0 or np.isinf(1 / alpha):
1556 # Handling of alpha near or at zero
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001557 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001558 scalar = 0
1559 else:
1560 quantization.scale_f32 = alpha
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001561 if alpha_dtype == DataType.int32:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001562 # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001563 scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
1564 else:
1565 scalar = 1
Tim Hall3b1578e2023-01-13 17:57:25 +00001566 alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001567 mul_alpha.add_input_tensor(alpha_tens)
1568 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1569 mul_alpha.set_output_tensor(fm_alpha)
1570 mul_alpha.set_ifm_ofm_shapes()
1571 DebugDatabase.add_optimised(op, mul_alpha)
1572
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001573 if not use_mul_max:
1574 relu_op = Operation(Op.Relu, op.name + "_relu")
1575 relu_op.add_input_tensor(ifm)
1576 fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1577 relu_op.set_output_tensor(fm_id)
1578 relu_op.set_ifm_ofm_shapes()
1579 DebugDatabase.add_optimised(op, relu_op)
1580 elif check_quantized_tens_scaling_equal(ifm, ofm):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001581 # No identity multiplication is needed
1582 fm_id = ifm
1583 else:
1584 # Add multiplication with identity
1585 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1586 mul_identity.add_input_tensor(ifm)
1587 # Create const tensor containing identity as scalar
1588 quantization = ifm.quantization.clone()
1589 quantization.min = 0
1590 quantization.max = quantization.quant_max - quantization.quant_min
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001591 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001592 quantization.zero_point = 0
Tim Hall3b1578e2023-01-13 17:57:25 +00001593 identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001594 mul_identity.add_input_tensor(identity_tens)
1595 # Make sure that fm_id is allocated to a different address than fm_alpha
1596 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1597 mul_identity.set_output_tensor(fm_id)
1598 mul_identity.set_ifm_ofm_shapes()
1599 DebugDatabase.add_optimised(op, mul_identity)
1600
1601 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001602 op.type = new_op
1603 op.name = op.name.replace("LeakyRelu", new_op.name)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001604 op.inputs = []
1605 ifm.consumer_list.remove(op)
1606 op.add_input_tensor(fm_alpha)
1607 op.add_input_tensor(fm_id)
1608 op.set_ifm_ofm_shapes()
1609
1610 DebugDatabase.add_optimised(op, op)
1611 return op
1612
1613
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001614def convert_to_lut8(op, fn, fn_name):
1615 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1616 # fn is a function(real) -> real
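# Rough worked example (hypothetical quantization): with ifm_scale = 0.1, zp_in = 0,
# ofm_scale = 1 / 256 and zp_out = -128, the tanh entry for input value x = 10 would be
# round(-128 + tanh(1.0) / (1 / 256)) = round(-128 + 0.7616 * 256) = 67.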
1617 ifm, ofm = op.get_ifm_ofm()
1618 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1619 return op
1620 # Generate the LUT
1621 ifm_scale = np.double(ifm.quantization.scale_f32)
1622 ofm_scale = np.double(ofm.quantization.scale_f32)
1623 zp_in = ifm.quantization.zero_point
1624 zp_out = ofm.quantization.zero_point
1625 values = []
1626 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1627 quantized_min = min(ix)
1628 quantized_max = max(ix)
1629 for x in ix:
1630 x_real = ifm_scale * (x - zp_in)
1631 y_real = fn(x_real)
1632 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1633 lut_result = min(quantized_max, max(quantized_min, lut_result))
1634 values.append(lut_result)
1635 return convert_to_lut(op, values, fn_name)
1636
1637
1638def convert_lrelu_to_lut(op, arch):
1639 ifm, ofm = op.get_ifm_ofm()
1640 # Generate the LUT
1641 alpha = op.attrs["alpha"]
1642 ifm_scale = np.double(ifm.quantization.scale_f32)
1643 ofm_scale = np.double(ofm.quantization.scale_f32)
1644 zp_in = ifm.quantization.zero_point
1645 zp_out = ofm.quantization.zero_point
1646 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1647 alpha_scalar = 1
1648 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1649 if "alpha_scaling" in op.attrs:
1650 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1651 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1652 values = []
1653 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1654 quantized_min = min(ix)
1655 quantized_max = max(ix)
1656 for x in ix:
1657 if x < zp_in:
1658 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1659 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1660 )
1661 else:
1662 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1663 lut_result = min(quantized_max, max(quantized_min, lut_result))
1664 values.append(lut_result)
1665 return convert_to_lut(op, values, "lrelu")
1666
1667
Raul Farkas66207142023-05-25 11:15:20 +01001668def convert_lrelu(op: Operation, arch, nng) -> Operation:
1669 """Convert LeakyRelu to a LUT based solution if possible, otherwise a mul + max."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001670 if op.type != Op.LeakyRelu:
1671 return op
1672 ifm, ofm = op.get_ifm_ofm()
1673 if ifm is None or ofm is None:
1674 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001675 alpha = op.attrs["alpha"]
1676 if alpha == 0:
1677 # When alpha is 0 the operation can be converted to a ReLU
1678 op.type = Op.Relu
1679 op.name = op.name.replace("LeakyRelu", op.type.name)
1680 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001681 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1682 # use LUT for int8/uint8
1683 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001684 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001685 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001686 return op
1687 return convert_lrelu_to_mul_max(op, arch)
1688
1689
Raul Farkas66207142023-05-25 11:15:20 +01001690def convert_tanh_sigmoid_to_lut(op: Operation, arch, nng) -> Operation:
1691 """Convert int8/uint8 Sigmoid and Tanh to a LUT based solution."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001692 if op.type == Op.Sigmoid:
1693 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1694 elif op.type == Op.Tanh:
1695 return convert_to_lut8(op, math.tanh, "tanh")
1696 return op
1697
1698
Johan Gunnarsson98556372023-08-10 13:10:44 +02001699def convert_quantize(op: Operation, arch, nng) -> Operation:
1700 """Convert Quantize to Avgpool. This conversion only works for int-to-int re-quantization and
1701 not to/from floats. Therefore, this rewrite should only run after the supported ops check to
1702 avoid rewriting ops that will run on CPU."""
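# Sketch of the effect (hypothetical quantizations): an int8 -> int8 requantize becomes a
# pass-through AvgPool that simply copies the IFM; because the IFM and OFM keep their
# differing quantization parameters, the NPU applies the rescale as part of the pooling op.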
1703 if op.type == Op.Quantize:
1704 # Create a new AvgPool op and steal its attrs, then reuse the original op with different type
1705 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
1706 op.type = Op.AvgPool
1707 op.attrs = avgpool_op.attrs.copy()
1708
1709 DebugDatabase.add_optimised(op, op)
1710
1711 return op
1712
1713
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001714def fuse_activation_function_with_prev(op, arch, nng):
1715 # if op is a no-op: attempts to move the activation function to the preceding op
1716 if not op.attrs.get("is_nop", False) or op.activation is None:
1717 return op
1718 ifm, ofm = op.get_ifm_ofm()
1719 if ifm is None or ofm is None:
1720 return op
1721 # finds the input(s) to the operation
1722 prev_op = ifm.ops[0]
1723 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1724 fuse = (
1725 prev_op.run_on_npu
1726 and prev_op.type.npu_block_type != NpuBlockType.Default
1727 and len(ifm.ops) == 1
1728 and len(prev_op.outputs[0].consumers()) == 1
1729 and prev_op.activation is None
1730 )
1731 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1732 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1733 # LUT currently only works correctly for elementwise ops
1734 fuse = False
1735 if not fuse:
1736 return op
1737 # Move the fused activation function + corresponding info to prev_op
1738 prev_op.activation = op.activation
1739 prev_op.forced_output_quantization = op.forced_output_quantization
1740 if op.activation_lut is not None:
1741 prev_op.set_activation_lut(op.activation_lut)
1742 # Bypass op
1743 prev_op.set_output_tensor(ofm)
wilisa0179a89042022-11-02 17:18:43 +00001744 DebugDatabase.add_optimised(prev_op, prev_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001745 return op
1746
1747
1748def _leading_pad_ok(leading_pad, stride, kernel_size):
1749 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1750 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
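# Example: kernel_size = 7 and stride = 2 give max_size = 3; a leading pad of 3 (the maximum)
# or 2 (a multiple of the stride) is acceptable, but a leading pad of 1 would make the
# hardware start at the wrong IFM row/column, so that PAD cannot be folded into hw padding.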
1751 max_size = kernel_size // 2
1752 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1753
1754
Raul Farkas66207142023-05-25 11:15:20 +01001755def replace_pad_by_hw_pad(op: Operation, arch, nng) -> Operation:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001756 """
1757 Tries to completely remove a PAD operator by using hardware padding.
1758 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1759 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1760 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1761 if both operations can be run on the NPU.
1762 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1763 """
1764 if (
1765 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001766 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001767 and op.run_on_npu
1768 and op.attrs["padding"] == Padding.VALID
1769 ):
1770 pad_op = op.ifm.ops[0]
1771 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1772 return op
1773 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1774 return op
1775 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1776 k = op.kernel
1777 k_w, k_h = k.dilated_wh()
1778
1779 # Check if the PAD operator can be replaced by hardware padding
1780 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1781 # Too much padding, it would require hardware padding to actually insert zeros
1782 return op
1783 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1784 return op
1785
1786 if op.type.is_avgpool_op():
1787 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1788 for pad, k_size in (
1789 (left, k_w),
1790 (right, k_w),
1791 (top, k_h),
1792 (bottom, k_h),
1793 ):
1794 if pad not in (0, k_size // 2):
1795 return op
1796 # Average pool is converted to depthwise, because NPU average pool + same padding
1797 # has a special implementation that is different from PAD followed by average pool with
1798 # valid padding.
1799 k_w, k_h = op.kernel.width, op.kernel.height
1800 ifm = op.ifm
1801 # Remember other inputs
1802 other_inputs = op.inputs[1:]
1803 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1804 quantization = QuantizationParameters(0.0, 255.0)
1805 quantization.scale_f32 = 1.0 / (k_w * k_h)
1806 quantization.zero_point = 0
1807 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1808 weights = np.full(shape, 1)
1809
1810 weight_tens = create_const_tensor(
1811 op.name + "_weights",
1812 shape,
1813 op.ifm.dtype,
1814 weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001815 purpose=TensorPurpose.Weights,
1816 quantization=quantization,
1817 )
James Peet7519d502021-07-19 16:47:58 +01001818 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001819 op.type = Op.DepthwiseConv2DBias
1820 op.inputs = []
1821 op.add_input_tensor(ifm)
1822 op.add_input_tensor(weight_tens)
Tim Hall5ff4cd12023-05-16 22:39:14 +01001823
1824 if op.ifm.dtype == DataType.uint8:
1825 op.rounding_mode = RoundingMode.HalfUp
1826
1827 # Add bias tensor, all biases set to 0
1828 op.inputs.append(None)
1829 fixup_bias_tensors(op, arch, nng, DataType.int32)
1830
1831 else:
1832 op.rounding_mode = RoundingMode.AwayZero
1833
1834 # The DepthwiseConv needs to be performed with the IFM zero point set appropriately so that the correct
1835 # pad values are used. However, in order to use the rounding away from zero mode the zero point needs to
1836 # have been removed so that the zero point is at zero. This is done by adding a kernel sized amount of
1837 # the zero point as a bias. The datatype of the bias needs to be set to int32, even for an int16 IFM,
1838 # because this will cause full precision scaling to be used (see weight compression). Finally, the OFM
1839 # zero point will need forcing to zero (as it has already been removed)
1840 nr_biases = op.inputs[1].shape[-1]
1841 bias_values = [op.ifm.quantization.zero_point * k_h * k_w] * nr_biases
1842 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
1843 op.add_input_tensor(bias_tensor)
1844
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001845 # Add other inputs
1846 op.inputs.extend(other_inputs)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001847
1848 # Bypass the PAD operator
1849 op.set_input_tensor(pad_op.ifm, 0)
1850 # Adjust the padding attributes of the convolution operator
1851 op.attrs["padding"] = Padding.EXPLICIT
1852 op.attrs["explicit_padding"] = (top, left, bottom, right)
1853 op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +00001854 DebugDatabase.add_optimised(op, op)
1855
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001856 return op
1857
1858
1859def convert_pad(op: Operation, arch, nng):
1860 """
1861 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1862 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1863 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1864 """
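# Sketch of the result (hypothetical padding): for top = 1, left = 2 on a 1xHxWxC IFM, the
# main avgpool writes the IFM into the OFM at offset (1, 2) and up to four extra avgpools
# write constant tensors holding the OFM zero point into the 1-row/2-column border strips.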
1865 if op.type != Op.Pad or not op.run_on_npu:
1866 return op
1867 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1868
1869 ifm = op.ifm
1870 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001871 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001872 ofm = op.ofm
1873 assert ofm is not None
1874 ofm.ops = []
1875 ofm_shape = op.ofm_shapes[0]
1876
1877 # Average pool op that copies IFM to the right place inside the OFM
1878 shp0 = Shape4D(0, 0, 0, 0)
1879 shp_top = shp0.with_height(top)
1880 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1881 avgpool_op.activation = op.activation
1882 quant = ofm.quantization
1883 pad_value = quant.zero_point
1884 # Add operations that fill the borders of the OFM
1885 if top > 0:
1886 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1887 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001888 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001889 )
1890 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1891 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1892 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1893 if bottom > 0:
1894 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1895 zero_tens = create_const_tensor(
1896 op.name + "_bottom",
1897 shape.as_list(),
1898 ofm.dtype,
1899 shape.elements() * [pad_value],
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001900 quantization=quant,
1901 )
1902 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1903 create_avg_pool_for_concat(
1904 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1905 )
1906 if left > 0:
1907 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1908 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001909 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001910 )
1911 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1912 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1913 if right > 0:
1914 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1915 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001916 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001917 )
1918 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1919 create_avg_pool_for_concat(
1920 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1921 )
1922
1923 op.type = Op.ConcatTFLite
1924 return avgpool_op
1925
1926
Raul Farkas66207142023-05-25 11:15:20 +01001927def fixup_bias_tensors(op: Operation, arch, nng, dtype=None) -> Operation:
1928 """Fixup ops that require a bias and don't have one by adding a bias tensor filled with zeros."""
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001929 if op.type.needs_bias() and op.bias is None:
1930 # Op has no bias, add bias tensor filled with zeros
1931 nr_biases = op.inputs[1].shape[-1]
1932 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001933 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1934 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1935 # For int16 the selected bias DataType will have an impact on the scaling
1936 # used when encoding the scales and biases later. The default mode will match the
1937 # reference with reduced scaling for int64 bias.
1938 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1939 # is used to emulate average pool int32 bias should be selected for full precision
1940 # int16 scaling.
1941 if dtype is None:
1942 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1943 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Raul Farkas3e7157b2023-05-09 09:09:17 +01001944 bias_index = op.type.info.indices.biases[0]
1945 if bias_index < len(op.inputs):
1946 op.set_input_tensor(bias_tensor, bias_index)
1947 else:
1948 op.add_input_tensor(bias_tensor)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001949
1950 return op
1951
1952
wilisa0146c94772023-02-08 09:56:14 +00001953def detect_asymmetric_weights(op):
1954 # Check all ops (cpu and npu)
1955 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
1956 if op.ifm.dtype in (DataType.int8, DataType.int16):
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001957 if not np.all(op.weights.quantization.zero_point == 0):
wilisa0146c94772023-02-08 09:56:14 +00001958 print(f"Warning: Op {op.type} '{op.name}' has asymmetric weights.", end=" ")
1959 return True
1960 return False
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001961
wilisa0146c94772023-02-08 09:56:14 +00001962
Raul Farkas66207142023-05-25 11:15:20 +01001963def fixup_asymmetric_weights(op: Operation, arch, nng) -> Operation:
wilisa0146c94772023-02-08 09:56:14 +00001964 if detect_asymmetric_weights(op):
1965 if op.run_on_npu:
1966 print("Zero points have been adjusted.")
1967 op.weights.quantization.zero_point *= 0
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001968 return op
1969
1970
wilisa0146c94772023-02-08 09:56:14 +00001971def check_asymmetric_weights(op, arch, nng):
1972 # This function can modify the run_on_npu flag which causes an operator to be placed on the CPU. It is usually only
1973 # set by the supported operator checks. Therefore, it should be run immediately after those checks to avoid the
1974 # possibility of other graph optimiser functions modifying the operator (that is later run on the CPU)
1975 if detect_asymmetric_weights(op):
1976 if op.run_on_npu:
1977 print("To run the operator on Ethos-U use the option --force-symmetric-int-weights")
1978 op.run_on_npu = False
1979 return op
1980
1981
1982def fixup_or_check_asymmetric_weights(force_symmetric_int_weights):
1983 if force_symmetric_int_weights:
1984 return fixup_asymmetric_weights
1985 else:
1986 return check_asymmetric_weights
1987
1988
Rickard Bolina68b82a2023-04-20 15:12:28 +00001989def convert_mean_to_depthwise_conv(op, arch, nng):
Alexander Hansson90c34b52023-05-31 15:03:03 +00001990 """
1991 When h x w <= 4096 the mean is lowered to a single DepthwiseConv2DBias followed by a Mul (left column below).
1992 When h x w > 4096 there is a need to split into several convs whose results are added; do this by splitting
1993 up h and changing the read_offset/read_shape of each conv (right column). Below is an example where ifm is 1x190x64x1
1994 MEAN MEAN
1995 | |-----------------------|----------------------|
1996 DepthwiseConv2DBias 1_DepthwiseConv2DBias 2_DepthwiseConv2DBias 3_DepthwiseConv2DBias
1997 | | | |
1998 MUL |---------ADD-----------| |
1999 | |
2000 |----------------ADD---------------|
2001 |
2002 MUL
2003 1_DepthwiseConv2DBias: read_offset [0, 0, 0, 0]> read_shape [1, 64, 64, 1]>
2004 2_DepthwiseConv2DBias: read_offset [0, 64, 0, 0]> read_shape [1, 64, 64, 1]>
2005 3_DepthwiseConv2DBias: read_offset [0, 128, 0, 0]> read_shape [1, 62, 64, 1]>
2006 """
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002007 if op.type == Op.Mean and op.run_on_npu:
Alexander Hansson90c34b52023-05-31 15:03:03 +00002008 max_kernel_size = 4096
2009 max_height = 64
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002010 inp, axis = op.inputs
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002011 dims = len(inp.shape)
2012 dims_ofm = len(op.ofm.shape)
Alexander Hansson90c34b52023-05-31 15:03:03 +00002013 ofmq = op.ofm.quantization
2014 ifmq = op.ifm.quantization
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002015
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002016 # reduce_axis[i] is true if axis i should be reduced
2017 if axis.shape == []:
2018 reduce_axis = [True if i == axis.values else False for i in range(dims)]
2019 else:
2020 reduce_axis = [True if i in axis.values else False for i in range(dims)]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002021
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002022 ifm_shape = inp.shape.copy()
2023 intermediate_shape = op.ofm.shape.copy()
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01002024
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002025 # Fix intermediate_shape when keep_dims is false
2026 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the intermediate_shape should be 1xHx1xC
2027 if dims_ofm < dims:
2028 for i in range(dims):
2029 if reduce_axis[i]:
2030 intermediate_shape.insert(i, 1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002031
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002032 # Reshape to 4D
Alexander Hanssonda8741a2023-06-30 15:41:13 +00002033 reduce_axis = full_shape(4, reduce_axis, False)
2034 ifm_shape = full_shape(4, ifm_shape, 1)
2035 intermediate_shape = full_shape(4, intermediate_shape, 1)
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002036
2037 # If all dimensions to reduce have shape 1, the operation is essentially a memcpy.
2038 # We can then remove the whole op by propagating ofm to previous ops
2039 if not any([reduce_axis[i] and ifm_shape[i] > 1 for i in range(4)]):
2040 op.type = Op.Memcpy
2041 op = bypass_memory_only_ops(op, arch, nng)
2042 return op
2043
Alexander Hanssonda8741a2023-06-30 15:41:13 +00002044 # Support mean over depth-axis by left-shifting the C channel
2045 # From semantics checks we can assume that one of H,W,C has shape 1
2046 if reduce_axis[3] and ifm_shape[3] > 1:
2047 assert 1 in ifm_shape[1:], "Mean reduction over depth channel, but none of H,W,C has shape 1"
2048 # If W=1 reshape NxHx1xC -> NxHxCx1, else reshape Nx1xWxC -> NxWxCx1
2049 idx_to_del = 2 if ifm_shape[2] == 1 else 1
2050
2051 # Delete axis with size 1
2052 del reduce_axis[idx_to_del]
2053 del ifm_shape[idx_to_del]
2054 del intermediate_shape[idx_to_del]
2055
2056 # Add another element to set channel-axis to one
2057 reduce_axis.append(False)
2058 ifm_shape.append(1)
2059 intermediate_shape.append(1)
2060
2061 # Compute kernel sizes for our convolutions
2062 # Batch axis is implicit as it is only supported if batch size is 1.
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002063 h = ifm_shape[1] if reduce_axis[1] else 1
2064 w = ifm_shape[2] if reduce_axis[2] else 1
2065
Alexander Hansson90c34b52023-05-31 15:03:03 +00002066 num_elements_in_axis = h * w
2067
2068 # If one convolution is enough, but height is greater than max kernel height
2069 # reshape from HxW to 1x(HxW)
2070 # This can only be done if the mean is computed over both H and W
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002071 if h > max_height and num_elements_in_axis <= max_kernel_size and reduce_axis[1] and reduce_axis[2]:
2072 ifm_shape = [ifm_shape[0], 1, h * w, ifm_shape[3]]
Alexander Hansson90c34b52023-05-31 15:03:03 +00002073 w = h * w
2074 h = 1
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002075
Alexander Hansson90c34b52023-05-31 15:03:03 +00002076 intermediate_op = None
2077 height_per_conv = min(max_kernel_size // w, h)
2078 height_per_conv = min(height_per_conv, max_height)
2079 num_convs = math.ceil(h / height_per_conv)
2080 convs = list()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002081
Alexander Hansson90c34b52023-05-31 15:03:03 +00002082 for i in range(num_convs):
2083 is_last_op = i == (num_convs - 1)
2084
2085 intermediate_op = op.clone(f"{op.name}_conv_{i}")
2086
2087 intermediate_op.type = Op.DepthwiseConv2DBias
2088
2089 # Set necessary depthwise attributes
2090 intermediate_op.attrs.update(
2091 {
2092 "padding": Padding.VALID,
2093 "stride_h": 1,
2094 "stride_w": 1,
2095 "strides": (1, 1, 1, 1),
2096 "depth_multiplier": 1,
2097 "channel_multiplier": 1,
2098 "dilation_h_factor": 1,
2099 "dilation_w_factor": 1,
2100 "dilation": (1, 1, 1, 1),
2101 }
2102 )
2103
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002104 b, _, _, c = ifm_shape
Alexander Hansson90c34b52023-05-31 15:03:03 +00002105
2106 intermediate_tensor = op.ofm.clone(suffix=f"_conv_sum_{i}", set_unique=True)
2107 intermediate_tensor.dtype = DataType.int32
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002108 intermediate_tensor.shape = intermediate_shape
Alexander Hansson90c34b52023-05-31 15:03:03 +00002109 intermediate_op.set_output_tensor(intermediate_tensor)
2110
2111 # as we have several convs, scaling/rounding must be done after the sum has been calculated
2112 intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
2113
2114 # compute height for the kernel
2115 if is_last_op and h % height_per_conv != 0:
2116 weight_h = h % height_per_conv
2117 else:
2118 weight_h = height_per_conv
2119
2120 # compute ifm read offset and shape for the convolution
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002121 read_shape_h = weight_h if reduce_axis[1] else ifm_shape[1]
2122 read_shape_w = w if reduce_axis[2] else ifm_shape[2]
Alexander Hansson90c34b52023-05-31 15:03:03 +00002123
2124 intermediate_op.read_offsets[0] = Shape4D([0, i * height_per_conv, 0, 0])
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002125 intermediate_op.read_shapes[0] = Shape4D(ifm_shape).with_hw(read_shape_h, read_shape_w)
Alexander Hansson90c34b52023-05-31 15:03:03 +00002126
2127 weight_quant = QuantizationParameters(0, 255, scale_f32=1.0, zero_point=0)
2128 weight_shape = [weight_h, w, c, b]
2129 weight_tensor = create_const_tensor(
2130 f"{intermediate_op.name}_weights",
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002131 weight_shape,
Alexander Hansson90c34b52023-05-31 15:03:03 +00002132 DataType.uint8,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002133 np.ones(weight_shape),
Alexander Hansson90c34b52023-05-31 15:03:03 +00002134 TensorPurpose.Weights,
2135 quantization=weight_quant,
2136 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002137
Alexander Hansson90c34b52023-05-31 15:03:03 +00002138 weights_1D = np.ones(np.prod(weight_shape))
2139 weight_tensor.equivalence_id = create_equivalence_id(tuple(weights_1D))
2140 weight_tensor.value_id = weight_tensor.equivalence_id
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002141
Alexander Hansson90c34b52023-05-31 15:03:03 +00002142 intermediate_op.set_input_tensor(weight_tensor, 1)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002143
Alexander Hansson90c34b52023-05-31 15:03:03 +00002144 dtype = DataType.int64 if intermediate_op.ifm.dtype == DataType.int16 else DataType.int32
2145 bias_values = [0] * c
2146 bias = create_const_tensor(f"{intermediate_op.name}_bias", [c], dtype, bias_values)
2147 bias.equivalence_id = create_equivalence_id(tuple(bias_values))
2148 bias.value_id = bias.equivalence_id
2149 intermediate_op.inputs.append(bias)
2150 intermediate_op.set_ifm_ofm_shapes()
Johan Alfven7b3008a2023-04-13 18:54:47 +02002151
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002152 # We want to avoid reshaping the ifm tensor directly, to not affect other ops
Alexander Hansson90c34b52023-05-31 15:03:03 +00002153 # so we update the shape explicitly for this operation
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002154 intermediate_op.ifm_shapes[0] = Shape4D(ifm_shape)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002155
Alexander Hansson90c34b52023-05-31 15:03:03 +00002156 convs.append(intermediate_op)
2157 DebugDatabase.add_optimised(op, intermediate_op)
2158
2159 # If we have more than one convolution
2160 # We use add operations to accumulate the intermediate tensors
2161 if len(convs) > 1:
2162 prev_add_op = None
2163 idx = 0
2164
2165 while len(convs):
2166 intermediate_tensor = op.ofm.clone(suffix=f"_add_sum_{idx}", set_unique=True)
2167 intermediate_tensor.dtype = DataType.int32
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002168 intermediate_tensor.shape = intermediate_shape
Alexander Hansson90c34b52023-05-31 15:03:03 +00002169
2170 one_scale_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
2171
2172 ifm = convs.pop().ofm
2173 if not prev_add_op:
2174 ifm2 = convs.pop().ofm
2175 else:
2176 ifm2 = prev_add_op.ofm
Alexander Hansson90c34b52023-05-31 15:03:03 +00002177 intermediate_op = create_add(f"{op.name}_add_{idx}", ifm, ifm2, one_scale_quant)
2178 intermediate_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])
2179 intermediate_op.set_output_tensor(intermediate_tensor)
2180 intermediate_op.set_ifm_ofm_shapes()
2181
2182 prev_add_op = intermediate_op
2183 idx += 1
2184
2185 DebugDatabase.add_optimised(op, intermediate_op)
2186
2187 # Convert the original mean op to our final Mul operation
2188 # Which scales and divides by num_elements_in_axis
2189 op.type = Op.Mul
2190 op.name = f"{op.name}_mul"
2191 op.attrs = {}
2192 op.set_input_tensor(intermediate_op.ofm, 0)
Rickard Bolina68b82a2023-04-20 15:12:28 +00002193
Johan Alfven7b3008a2023-04-13 18:54:47 +02002194 # The multiplier is calculated in the same way as in the reference,
2195 # clamping the shift value at the price of some precision loss.
Johan Alfven7b3008a2023-04-13 18:54:47 +02002196 output_multiplier, output_shift_vela = quantise_scale(np.double(ifmq.scale_f32) / np.double(ofmq.scale_f32))
2197
2198 # Convert to reference representation shift value
2199 output_shift = 31 - output_shift_vela
2200
2201 # Reference calculation
2202 # round_down_log2 same as 63 - CountLeadingZeros(num_elements_in_axis)
2203 shift = round_down_log2(num_elements_in_axis)
2204 shift = min(shift, 32)
2205 shift = min(shift, 31 + output_shift)
2206 output_multiplier = (output_multiplier << shift) // num_elements_in_axis
2207 output_shift = output_shift - shift
2208
2209 # Convert to vela representation shift
2210 output_shift_vela = 31 - output_shift
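# Hedged worked example (assuming the two clamps above do not bind): for the 1x190x64x1 case
# in the docstring, num_elements_in_axis = 190 * 64 = 12160 and round_down_log2(12160) = 13,
# so the multiplier is scaled up by 2**13 before the integer division by 12160, folding the
# 1/N averaging into the quantised multiplier while the shift bookkeeping absorbs the 2**13.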
2211
2212 # For int32 scaling is not supported so instead multiply with the scale
2213 # intermediate * scale -> round and shift.
Alexander Hansson90c34b52023-05-31 15:03:03 +00002214 identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
Johan Alfven7b3008a2023-04-13 18:54:47 +02002215 scalar = create_const_tensor(
2216 op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [output_multiplier], quantization=identity_quant
2217 )
Alexander Hansson90c34b52023-05-31 15:03:03 +00002218 op.set_input_tensor(scalar, 1)
2219 op.set_ifm_ofm_shapes()
Alexander Hansson1d5e8592023-06-27 12:36:25 +00002220 op.ofm_shapes[0] = Shape4D(intermediate_shape)
Johan Alfven7b3008a2023-04-13 18:54:47 +02002221
2222 # Reference using TFL rounding for the multiply
Alexander Hansson90c34b52023-05-31 15:03:03 +00002223 op.rounding_mode = RoundingMode.TFLite
Johan Alfven7b3008a2023-04-13 18:54:47 +02002224
2225 # Need to use explicit scaling to get the wanted shift
Alexander Hansson90c34b52023-05-31 15:03:03 +00002226 op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
2227 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002228 return op
2229
2230
Raul Farkas66207142023-05-25 11:15:20 +01002231def convert_ops_to_lut(op: Operation, arch, nng) -> Operation:
2232 """Convert Exp to an 8-bit or 16-bit LUT and Rsqrt to an 8-bit LUT to allow for support on NPU."""
Johan Alfvence502732023-04-24 13:35:40 +02002233 if op.type == Op.Exp:
2234 if op.ifm.dtype == DataType.int8:
2235 return create_lut_8bit_op(op, math.exp, "exp")
2236 elif op.ifm.dtype == DataType.int16:
2237 return create_lut_int16_op(op, math.exp, "exp")
2238 else:
2239 # Should already be caught in the tflite supported ops checks
2240 assert False, f"Unsupported data type {op.ifm.dtype} for {op.type}"
2241
Johan Alfven8e525ca2023-05-07 13:12:37 +02002242 if op.type == Op.Rsqrt:
2243 return create_lut_rsqrt_int8_op(op)
2244
Johan Alfvence502732023-04-24 13:35:40 +02002245 return op
2246
2247
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002248def optimise_quantize(op: Operation, arch, nng):
2249
2250 if op.type == Op.Quantize and op.run_on_npu:
2251
2252 ifm, ofm = op.get_ifm_ofm()
2253 input_values = ifm.values
2254
2255 # Guard clause - input not const or no values to quantize
2256 if ifm.ops[0].type != Op.Const or input_values is None:
2257 return op
2258
2259 # Singular val in numpy array, convert to indexable array
2260 if input_values.ndim == 0:
2261 input_values = np.array([input_values])
2262
Fredrik Svedberg11563172022-07-06 14:54:12 +02002263 # Requantize int8 to int8 or int16 to int16
2264 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002265
2266 # scale needs to use double precision to match TFLite reference kernel
2267 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
2268 effective_multiplier, effective_shift = quantise_scale(effective_scale)
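# Hypothetical example: ifm scale 0.5, zero point 0 and ofm scale 0.25, zero point 10 give
# effective_scale = 2.0, so a stored int8 value of 7 becomes roughly 7 * 2 + 10 = 24 after
# the fixed-point multiply below (then clamped to the output quantization range).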
2269
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002270 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002271 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002272 input_val = val - ifm.quantization.zero_point
2273
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002274 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
2275 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002276
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002277 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
2278 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002279
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002280 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
2281 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002282
2283 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002284 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002285
2286 quantized_vals = []
2287 for val in input_values:
2288
2289 # Derive quantized value
2290 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002291 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
2292 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002293
2294 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002295 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
2296
2297 # Unsupported data type
2298 else:
2299 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002300
2301 # Make quantize op const and disconnect from parent node
2302
2303 # Remove reference of the current quant op from the parent tensor's consumer list
2304 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
2305
2306 # Clear any references to parent node
2307 op.inputs = []
2308
2309 # Convert this quantize op to const
2310 op.type = Op.Const
2311
2312 return op
2313
2314
Ayaan Masood4965fae2022-06-29 11:30:57 +01002315def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
2316 """Static optimisation for SHAPE operator output value known at compile time"""
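# Example: a SHAPE op whose input tensor has shape [1, 224, 224, 3] is turned into a Const op
# whose output values are [1, 224, 224, 3], so the shape is materialised at compile time and
# nothing is executed at inference.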
2317
2318 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
2319
2320 if op.type == Op.Shape and op.run_on_npu:
2321
2322 ifm, ofm = op.get_ifm_ofm()
2323
2324 if len(ifm.shape) != ofm.shape[0]:
2325 return op
2326
2327 # Remove reference of the current shape op from the parent tensor's consumer list
2328 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
2329
2330 # Clear any references to parent node
2331 op.inputs = []
2332
2333 # Convert this SHAPE op to const
2334 op.type = Op.Const
2335
2336 # Add size calculation to shape output tensors
2337 ofm.values = np.array(ifm.shape)
2338
2339 return op
2340
2341
Johan Gunnarsson24570f02023-08-29 15:33:10 +02002342def fixup_pool_strides(op: Operation, arch, nng):
2343 """Fixup Pool strides when the kernel size is equal to IFM shape. Stride is then irrelevant."""
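# Example: a MaxPool over a 1x8x8x16 IFM with an 8x8 kernel produces a single output position,
# so whatever stride was specified (e.g. 2) can safely be rewritten to 1.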
Johan Gunnarsson7ccc5832023-09-07 12:28:28 +02002344 if op.type in (Op.AvgPool, Op.MaxPool, Op.QuantizedAvgPool, Op.QuantizedMaxPool):
Johan Gunnarsson24570f02023-08-29 15:33:10 +02002345 ifm, _ = op.get_ifm_ofm()
2346 kernel_w, kernel_h = op.get_kernel_size()
2347 if kernel_w == ifm.shape[2] and kernel_h == ifm.shape[1]:
2348 stride_n, _, _, stride_c = op.attrs["strides"]
2349 op.attrs["strides"] = (stride_n, 1, 1, stride_c)
2350 op.attrs["stride_w"] = 1
2351 op.attrs["stride_h"] = 1
2352
2353 return op
2354
2355
Raul Farkas66207142023-05-25 11:15:20 +01002356def fixup_dilation_gt2(op: Operation, arch, nng) -> Operation:
2357 """Fixup Conv2DBias and DepthwiseConv2DBias to allow dilation greater than 2."""
Tim Hallea4ba662022-11-11 18:19:53 +00002358 assert op.run_on_npu
2359 if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
2360 dilation_w, dilation_h = op.get_kernel_dilation()
2361
2362 # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
2363 # kernel
2364 if dilation_w > 2 or dilation_h > 2:
2365 kernel_w, kernel_h = op.get_kernel_size()
2366 kernel_ic = op.weights.shape[-2]
2367 kernel_oc = op.weights.shape[-1]
2368
2369 # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
2370 # of 2. this allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
2371 # odd = 1, even = 2
2372 hw_dilation_h = 1 if (dilation_h & 1) else 2
2373 hw_dilation_w = 1 if (dilation_w & 1) else 2
2374
2375 scale_dilation_h = dilation_h // hw_dilation_h
2376 scale_dilation_w = dilation_w // hw_dilation_w
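# Worked example: dilation 4 in an axis gives hw_dilation = 2 and scale_dilation = 2, so a
# 3x3 kernel is expanded to a sparse 5x5 kernel ((3 - 1) * 2 + 1) and the remaining factor of
# 2 comes from the hardware dilation, preserving the original effective dilation of 4.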
2377
2378 # create new empty kernel (HWIO format)
2379 new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
2380 new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1
2381
2382 new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
2383 new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)
2384
2385 # copy the original kernel values into the new sparse kernel
2386 for h in range(0, kernel_h):
2387 for w in range(0, kernel_w):
2388 new_h = h * scale_dilation_h
2389 new_w = w * scale_dilation_w
2390 new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]
2391
2392 # update the weight tensor with the new dilated kernel
2393 op.weights.shape = new_kernel_shape
2394 op.weights.values = new_kernel_values
2395
2396 # enable(=2) / disable(=1) hardware dilation
2397 op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1) # nhwc format
2398 op.attrs["dilation_h_factor"] = hw_dilation_h
2399 op.attrs["dilation_w_factor"] = hw_dilation_w
2400
2401 return op
2402
2403
Tim Hall2180a172023-03-10 18:11:34 +00002404def fixup_reshape(op, arch, nng):
2405 def _get_explicit_shape(implicit_shape, total_size):
2406 # the explicit shape is a copy of the implicit shape but with the special -1 (remaining size) value converted to
2407 # the appropriate value
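        # e.g. implicit_shape [-1, 4] with total_size 12: np.prod gives -4, so the -1 entry is replaced by
        # int(12 / 4) = 3 and the explicit shape becomes [3, 4]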
2408 if implicit_shape is None:
2409 return None
2410
2411 explicit_shape = list(implicit_shape)
2412 if -1 in explicit_shape:
2413 explicit_shape[explicit_shape.index(-1)] = int(total_size / abs(np.prod(implicit_shape)))
2414
2415 return explicit_shape
2416
2417 if op.type == Op.Reshape:
2418 ifm_tensor, _, ofm_tensor = op.get_ifm_ifm2_ofm()
2419 ifm_size = ifm_tensor.elements()
2420 ofm_shape = ofm_tensor.shape
2421
2422 new_shape_tensor_shape = op.inputs[1].values.flatten() if len(op.inputs) > 1 else None
2423 new_shape_tensor_shape = _get_explicit_shape(new_shape_tensor_shape, ifm_size)
2424
2425 new_shape_attribute = op.attrs.get("new_shape", None)
2426 new_shape_attribute = _get_explicit_shape(new_shape_attribute, ifm_size)
2427
2428 # if present the new shape tensor overrides the new_shape attribute
2429 if new_shape_tensor_shape is not None:
2430 # check tensor
2431 if not np.array_equal(new_shape_tensor_shape, ofm_shape):
2432 print(
2433 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new shape tensor"
2434 f" ({new_shape_tensor_shape}) that does not match output tensor shape {ofm_shape}. Will use output"
2435 f" tensor shape."
2436 )
2437 elif new_shape_attribute is not None:
2438 # check attribute
2439 if not np.array_equal(new_shape_attribute, ofm_shape):
2440 print(
2441 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' has new_shape attribute"
2442 f" ({new_shape_attribute}) that does not match output tensor shape {ofm_shape}. Will use output"
2443 f" tensor shape."
2444 )
2445 else:
2446 print(
2447 f"Warning: {optype_to_builtintype(op.type)} '{op.name}' does not have a new shape tensor or a new_shape"
2448 f" attribute. Will use output tensor shape {ofm_shape}."
2449 )
2450
2451 # force new shape tensor to output shape
2452 new_shape_tensor = create_const_tensor(
2453 op.name + "_new_shape", [len(ofm_shape)], DataType.int32, np.array(ofm_shape, np.int32)
2454 )
2455 if len(op.inputs) > 1:
2456 op.set_input_tensor(new_shape_tensor, 1)
2457 else:
2458 op.add_input_tensor(new_shape_tensor)
2459
2460 # force new_shape attribute to output shape
2461 op.attrs["new_shape"] = ofm_shape
2462
2463 return op
2464
2465
Tim Hall9cf63a32023-06-27 12:07:49 +01002466def convert_conv_groups(op: Operation, arch, nng):
2467 """
2468 Convert convolution groups to a split followed by separate convolutions and then a concat.
2469    This needs to run before the concat and split handling functions."""
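    # illustrative example: with num_conv_groups = 2, an IFM of depth 16 and 32 output filters, the op
    # becomes Split(axis=-1) into two depth-8 tensors, two convolutions of 16 filters each (with the
    # weights and bias sliced along the output channel axis), and a Concat(axis=-1) back to depth 32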
2470 if not op.type.is_conv2d_op():
2471 return op
2472
2473 num_conv_groups = op.attrs.get("num_conv_groups", 0)
2474 if num_conv_groups > 1:
2475 # convolution groups params
2476 ifm_depth_cg = op.ifm.shape[-1] // num_conv_groups
2477 num_filters_cg = op.weights.shape[-1] // num_conv_groups
2478
2479 # create split
2480 split_op = Operation(Op.Split, f"{op.name}_split")
2481 split_op.attrs.update(
2482 {
2483 "num_splits": num_conv_groups,
2484 }
2485 )
2486 # first input is the split axis
2487 split_op.add_input_tensor(
2488 # split along the depth axis
2489 create_const_tensor(f"{split_op.name}_axis", [0], DataType.int32, [-1])
2490 )
2491 # second input is the ifm
2492 split_op.add_input_tensor(op.ifm)
2493 # calculate shape of each ofm part
2494 split_op_ofm_shape = op.ifm.shape[:-1] + [ifm_depth_cg]
2495
2496 # create concat. do this prior to each conv group so that the for-loop can reference the concat as it iterates
2497 concat_op = Operation(Op.ConcatTFLite, f"{op.name}_concat")
2498 concat_op.attrs.update(
2499 {
2500 "axis": -1,
2501 "fused_activation_function": None,
2502 }
2503 )
2504 # calculate shape of each ifm part
2505 concat_op_ifm_shape = op.ofm.shape[:-1] + [num_filters_cg]
2506 # output is the concatenated tensor
2507 concat_op.set_output_tensor(op.ofm) # will disconnect ofm from op
2508
2509 # for each conv group
2510 for i in range(num_conv_groups):
2511 # cg params
2512 cg_oc_start = i * num_filters_cg
2513 cg_oc_end = (i + 1) * num_filters_cg
2514
2515 # split has multiple outputs
2516 split_op_ofm_part = Tensor(split_op_ofm_shape, op.ifm.dtype, f"{split_op.name}_out{i}")
2517 split_op_ofm_part.quantization = op.ifm.quantization.clone()
2518 split_op.add_output_tensor(split_op_ofm_part)
2519
2520 # concat has multiple inputs
2521 concat_op_ifm_part = Tensor(concat_op_ifm_shape, op.ifm.dtype, f"{concat_op.name}_in{i}")
2522 concat_op_ifm_part.quantization = op.ofm.quantization.clone()
2523 concat_op.add_input_tensor(concat_op_ifm_part)
2524
2525 # create convolution group operator
2526 conv_group_op = Operation(op.type, f"{op.name}_cg{i}")
2527 conv_group_op.attrs = op.attrs.copy()
2528 conv_group_op.attrs["num_conv_groups"] = 1
2529 # first input is the ifm
2530 conv_group_op.add_input_tensor(split_op_ofm_part)
2531            # second input is weights. the number of filters (i.e. the output channels) needs to be split equally
2532 # across all of the convolution groups
2533 conv_group_op_weights_shape = op.weights.shape[:-1] + [num_filters_cg]
2534 conv_group_op_weights_quant = op.weights.quantization.clone()
2535 conv_group_op_weights_quant.scale_f32 = op.weights.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
2536 conv_group_op_weights_quant.zero_point = op.weights.quantization.zero_point[..., cg_oc_start:cg_oc_end]
2537 conv_group_op.add_input_tensor(
2538 create_const_tensor(
2539 f"{op.weights.name}_cg{i}",
2540 conv_group_op_weights_shape,
2541 op.weights.dtype,
2542 op.weights.values[..., cg_oc_start:cg_oc_end],
2543 op.weights.purpose,
2544 conv_group_op_weights_quant,
2545 )
2546 )
2547 # third input is bias. like the weights, the bias needs to be split equally across all of the convolution
2548 # groups
2549 if op.bias is None:
2550 conv_group_op.add_input_tensor(None)
2551 else:
2552 conv_group_op_bias_shape = op.bias.shape[:-1] + [num_filters_cg]
2553 conv_group_op_bias_quant = op.bias.quantization.clone()
2554 conv_group_op_bias_quant.scale_f32 = op.bias.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
2555 conv_group_op_bias_quant.zero_point = op.bias.quantization.zero_point[..., cg_oc_start:cg_oc_end]
2556 conv_group_op.add_input_tensor(
2557 create_const_tensor(
2558 f"{op.bias.name}_cg{i}",
2559 conv_group_op_bias_shape,
2560 op.bias.dtype,
2561 op.bias.values[..., cg_oc_start:cg_oc_end],
2562 op.bias.purpose,
2563                        conv_group_op_bias_quant,
2564 )
2565 )
2566 # output goes to the concat
2567 conv_group_op.set_output_tensor(concat_op_ifm_part)
2568 # update the cg op shapes and debug db
2569 conv_group_op.set_ifm_ofm_shapes()
2570 DebugDatabase.add_optimised(op, conv_group_op)
2571
2572 # update the split/concat op shapes/debug db
2573 split_op.set_ifm_ofm_shapes()
2574 DebugDatabase.add_optimised(op, split_op)
2575 concat_op.set_ifm_ofm_shapes()
2576 DebugDatabase.add_optimised(op, concat_op)
2577
2578 # disconnect the original convolution operator.
2579 # the ofm has already been disconnected by concat_op.set_output_tensor()
2580 op.ifm.consumer_list.remove(op)
2581 op.inputs = []
2582 op.outputs = []
2583
2584 # return last op so that other graph optimiser functions can process the new operators
2585 op = concat_op
2586
2587 return op
2588
2589
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002590def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02002591 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002592 return op
2593
2594
wilisa0146c94772023-02-08 09:56:14 +00002595def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
Fredrik Svedberg11563172022-07-06 14:54:12 +02002596 # Compile time static optimisations
wilisa0146c94772023-02-08 09:56:14 +00002597 optimisation_list = [
2598 optimise_quantize,
2599 convert_shape_op_to_constant_tensor,
2600 fixup_or_check_asymmetric_weights(force_symmetric_int_weights),
Johan Gunnarsson24570f02023-08-29 15:33:10 +02002601 fixup_pool_strides,
wilisa0146c94772023-02-08 09:56:14 +00002602 ]
Ayaan Masood25f48dd2022-06-29 18:16:04 +01002603
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002604 for idx, sg in enumerate(nng.subgraphs):
2605 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02002606 nng,
2607 sg,
2608 arch,
2609 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01002610 optimisation_list,
2611 rewrite_unsupported=False,
2612 )
2613
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002614 # Pre-processing step
Tim Hall9cf63a32023-06-27 12:07:49 +01002615 pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape, convert_conv_groups]
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02002616
Ayaan Masood4965fae2022-06-29 11:30:57 +01002617 for idx, sg in enumerate(nng.subgraphs):
2618 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
2619 nng,
2620 sg,
2621 arch,
2622 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02002623 pre_process_list,
2624 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002625 )
2626
2627 # Handle Concat Ops
2628 for idx, sg in enumerate(nng.subgraphs):
2629 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
2630 sg.refresh_after_modification()
2631
2632 # Handle Split Ops
2633 for idx, sg in enumerate(nng.subgraphs):
2634 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
2635 nng,
2636 sg,
2637 arch,
2638 [],
2639 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
2640 rewrite_unsupported=False,
2641 )
2642
2643 for idx, sg in enumerate(nng.subgraphs):
2644 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02002645 nng,
2646 sg,
2647 arch,
2648 [rewrite_split_ops],
2649 [],
2650 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002651 )
2652
Johan Alfvena5e1b622023-02-02 14:59:03 +01002653 # Bypass or rewrite memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002654 for idx, sg in enumerate(nng.subgraphs):
2655 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02002656 nng,
2657 sg,
2658 arch,
2659 [],
Johan Alfvena5e1b622023-02-02 14:59:03 +01002660 [bypass_memory_only_ops],
Jonas Ohlssond8575072022-03-30 10:30:25 +02002661 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002662 )
2663
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002664 # Rewrite of operators
2665 op_rewrite_list = [
2666 set_tensor_equivalence,
Johan Alfvence502732023-04-24 13:35:40 +02002667 convert_ops_to_lut,
Rickard Bolina68b82a2023-04-20 15:12:28 +00002668 convert_mean_to_depthwise_conv,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002669 convert_depthwise_to_conv,
2670 convert_conv_to_fc,
Fredrik Svedberg0ac08042023-04-11 22:35:04 +02002671 convert_lstm,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002672 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02002673 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02002674 convert_mul_max_to_abs_or_lrelu,
2675 convert_lrelu,
Raul Farkas3e7157b2023-05-09 09:09:17 +01002676 convert_avg_pool_to_conv2d,
Raul Farkas69782af2023-05-09 10:39:52 +01002677 fixup_strided_conv,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002678 convert_hardswish_to_lut,
2679 rewrite_fully_connected_input,
2680 convert_batched_fc_shape,
2681 fixup_conv2d_backprop,
2682 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002683 reorder_depthwise_weights,
Rickard Bolin6986a072022-12-19 12:33:40 +00002684 convert_argmax_to_depthwise_conv_and_max_pool,
Tim Hall885033b2022-07-21 11:46:03 +01002685 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002686 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01002687 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002688 convert_tanh_sigmoid_to_lut,
Johan Gunnarsson98556372023-08-10 13:10:44 +02002689 convert_quantize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002690 replace_pad_by_hw_pad,
Tim Hallea4ba662022-11-11 18:19:53 +00002691 fixup_dilation_gt2,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002692 ]
2693
2694 for idx, sg in enumerate(nng.subgraphs):
2695 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02002696 nng,
2697 sg,
2698 arch,
2699 [],
2700 op_rewrite_list,
2701 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002702 )
2703
2704 for idx, sg in enumerate(nng.subgraphs):
2705 # remove passthrough tensors and attempt further optimizations
2706 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
2707 nng,
2708 sg,
2709 arch,
2710 [remove_passthrough_tensor],
2711 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
2712 )
2713
2714    # Removal of SplitSliceRead needs to be done after optimisation has been performed,
2715    # since the ifm/ofm_shapes are needed by this function
2716 for sg in nng.subgraphs:
2717 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
2718 sg.refresh_after_modification()
2719
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01002720 # Make sure that const optimisations on subgraph outputs are handled correctly
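    # e.g. a SHAPE op feeding a subgraph output may have been folded to a Const above; the copy below is
    # implemented as an elementwise Add with a zero constant so that a real op still writes the output tensor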
2721 for sg in nng.subgraphs:
2722 for ofm in sg.output_tensors:
2723 if ofm.is_const and ofm.ops[0].type_changed:
2724 # Subgraph output cannot be const - insert a memory copy
2725 op = ofm.ops[0]
2726 ofm_clone = ofm.clone()
2727 ofm_clone.values = ofm.values
2728 ofm.values = None
Tim Hall3b1578e2023-01-13 17:57:25 +00002729 zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01002730 memcpy = create_add_nop(f"{ofm.name}_copy")
2731 memcpy.add_input_tensor(ofm_clone)
2732 memcpy.add_input_tensor(zero)
2733 memcpy.set_output_tensor(ofm)
2734 memcpy.set_ifm_ofm_shapes()
2735 op.set_output_tensor(ofm_clone)
2736 DebugDatabase.add_optimised(op, memcpy)
2737
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002738 return nng