# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_depthwise_maxpool
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

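        # The graph is operated on in 4D shapes, so an axis given for an N-dimensional
        # shape is shifted by (4 - N); e.g. axis 2 of a 3D desired_shape maps to axis_4D 3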
        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt
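
# Worked example of the SAME-padding arithmetic above: for input width 224, stride 2 and a
# dilated kernel width of 3, needed_total_padding gives xpad = 1, so left_pad = (1 + 0) // 2 = 0
# and right_pad = (1 + 1) // 2 = 1; the odd padding column lands on the right, matching the
# TensorFlow SAME convention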


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
        DebugDatabase.add_optimised(op, op)

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    name = op.inputs[1].name + "_add"
    dtype = op.inputs[0].dtype
    shape = op.ofm_shapes[0].as_list()
    values = np.zeros(shape, dtype.as_numpy_type())
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = 1.0
    quantization.zero_point = 0
    op.inputs[1] = op.inputs[0]
    op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 0)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change resizebilinear to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype.type == BaseType.UnsignedInt:
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
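    # e.g. for upscale_factor = 2: centre_coeff = (2 // 2) * 2 + (2 // 2) = 3, so
    # weight_values becomes [0, 0, 0, 1], selecting value D in the diagram above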

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm_dtype,
            np.array(weight_values).reshape(weight_shape),
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shapes in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # Keep 1x1 kernel and average pool, this applies both when
            # half-pixel-centers is True and False. Calculations are the
            # same in the reference.
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, scaled_op)

    return op


def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
    """
    Convert ArgMax to DWConv2D->MaxPool->DWConv2D, see details below.

    Example:
    arr = [4,    [00000100,
           6, =   00000110,  # <-- This is the largest value, so we're expecting argmax(arr) = 1
           5]     00000101]

    Use 16-bit precision and shift all values 7 bits to the left:
    Shifted_arr = [0000001000000000,
                   0000001100000000,
                   0000001010000000]

    Add "c - index of channel" to each channel:
    Shifted_arr_plus_reverse_idx = [0000001000000010, (+2)
                                    0000001100000001, (+1)
                                    0000001010000000] (+0)

    The index is reversed since ArgMax selects the lowest index if the maximum value is found at two indices. The
    index acts as a tie-breaker between channels with equal values: since we want the smallest channel index to be
    chosen, we reverse the index before the maxpool and then subtract the index from the number of channels after
    the maxpool to get the correct index.

    Find the maximum value in the array:
    val = max(shifted_arr_plus_reverse_idx) = 0000001100000001

    The 7 least significant bits of val hold the reverse channel index; subtracting it from the number of channels
    minus one recovers the real index:
    idx = (c - 1) - reverse_idx = 2 - 1 = 1

    Both the extraction of the 7 lowest bits (cutting off the 9 most significant bits) and the subtraction are done
    in one step, using a LUT:
    idx = LUT(val) = 0000000000000001 = 1
    """

    if op.type == Op.ArgMax:
        ifm, ofm = op.inputs[0], op.outputs[0]
        identity_quant = QuantizationParameters()
        identity_quant.zero_point = 0
        identity_quant.scale_f32 = 1.0
        if ofm.quantization is None:
            ofm.quantization = identity_quant
        # Add last dimension to ofm shape
        ofm.shape += [1]
        ofm.ops = []

        # Create 1x1 Depthwise convolution with 2**7 weights for each channel to convert precision to 16 bit and shift
        # all values 7 bits to the left
        # Set necessary depthwise attributes
        dw_op_attrs = {
            "padding": Padding.VALID,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
            "explicit_padding": None,
        }
        op.name = "depthwise_conv_SHL_7"
        op.type = Op.DepthwiseConv2DBias
        op.attrs.update(dw_op_attrs)
        n, h, w, c = ifm.shape
        shape = [1, 1, 1, c]
        kernel = np.dstack([2**7] * c)
        op.inputs = []
        op.add_input_tensor(ifm)
        op.add_input_tensor(
            create_const_tensor(
                "weights",
                shape,
                DataType.uint8,
                np.array(kernel).reshape(shape),
                quantization=identity_quant,
            ),
        )
        # Let the bias for each channel be the "reverse" index of the channel it is in, i.e. (c - 1) - channel_idx
        reverse_idxs = list(reversed(range(c)))
        bias_tensor = create_const_tensor(op.name + "_bias", [c], DataType.int64, reverse_idxs)
        op.add_input_tensor(bias_tensor)

        intermediate_tens = Tensor([n, h, w, c], DataType.int16, "int16_and_shifted_7_bits_left")
        intermediate_tens.quantization = ifm.quantization
        op.set_output_tensor(intermediate_tens)
        op.set_ifm_ofm_shapes()
        orig_ifm_shape = op.ifm_shapes[0]
        DebugDatabase.add_optimised(op, op)

        # To extract 7 least significant bits and swap reverse index back to real index using a LUT activation, we set
        # the base value to c-1 and slope to -128. The 16-bit LUT uses a table of 32-bit values where the top 16 bits
        # represent the slope and bottom 16 bits the base which are used to interpolate the activation value.
        slope = (-128 & 0xFFFF) << 16  # Top 16 bits of 32 bit LUT table value
        base = c - 1  # Bottom 16 bits of the LUT table value
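        # e.g. for c = 3 channels each 32-bit table entry is ((-128 & 0xFFFF) << 16) + 2 = 0xFF800002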
        lut_tensor = create_const_tensor(
            "maxpool_LUT_extract_7_LSB",
            [1, 1, 1, 512],
            DataType.uint32,
            [slope + base] * 512,
            TensorPurpose.LUT,
        )

        # Split large feature maps into smaller chunks since the Depthwise Maxpool height dimension can overflow due to
        # flattening the ifm to (H*W)xCx1
        max_height = 2**16 // orig_ifm_shape.width
        num_full_height_ops = orig_ifm_shape.height // max_height
        last_op_height = orig_ifm_shape.height - max_height * num_full_height_ops
        op_heights = [max_height] * num_full_height_ops
        if last_op_height > 0:
            op_heights.append(last_op_height)
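        # e.g. a 300x512 IFM gives max_height = 65536 // 512 = 128 and op_heights = [128, 128, 44]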

        # Create maxpool output tensor which is reshaped to 1x(H*W)x1x1. The product H*W might be larger than the
        # maximum allowed height, but that's handled by reading and writing the data in chunks
        maxpool_ofm = Tensor([1, orig_ifm_shape.height * orig_ifm_shape.width, 1, 1], DataType.int16, "argmax_maxpool")
        maxpool_ofm.quantization = identity_quant

        for op_idx, op_height in enumerate(op_heights):
            maxpool_op = create_depthwise_maxpool(
                f"dw_maxpool_{op_idx}", intermediate_tens, orig_ifm_shape, identity_quant
            )
            maxpool_op.outputs = [maxpool_ofm]
            maxpool_ofm.ops.append(maxpool_op)
            maxpool_op.ofm_shapes = [Shape4D(maxpool_ofm.shape)]
            maxpool_op.set_activation_lut(lut_tensor)

            # Set read and write shapes/offsets to read/write chunks of the IFM/OFM
            maxpool_op.read_shapes[0] = Shape4D([1, op_height * orig_ifm_shape.width, orig_ifm_shape.depth, 1])
            maxpool_op.read_offsets[0] = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
            maxpool_op.write_shape = Shape4D([1, op_height * orig_ifm_shape.width, 1, 1])
            maxpool_op.write_offset = Shape4D([0, sum(op_heights[:op_idx]) * orig_ifm_shape.width, 0, 0])
            DebugDatabase.add_optimised(op, maxpool_op)

        # Convert output to OFM dtype and reshape back to original OFM shape with 1x1 DWConv
        dw_conv = Operation(Op.DepthwiseConv2DBias, f"depthwise_conv_convert_to_32bit_{op_idx}")
        dw_conv.attrs.update(dw_op_attrs)
        dw_conv.inputs = [maxpool_op.ofm]
        dw_conv.add_input_tensor(
            create_const_tensor(
                "weights",
                [1, 1, 1, 1],
                DataType.uint8,
                np.array([1]).reshape([1, 1, 1, 1]),
                quantization=identity_quant,
            ),
        )
        dw_conv.add_input_tensor(create_const_tensor(dw_conv.name + "_bias", [1], DataType.int64, [0]))
        ofm.ops.append(dw_conv)
        dw_conv.outputs = [ofm]
        dw_conv.ifm_shapes.append(Shape4D([1, orig_ifm_shape.height, orig_ifm_shape.width, 1]))
        dw_conv.ofm_shapes.append(Shape4D(ofm.shape))
        DebugDatabase.add_optimised(op, dw_conv)

    return op


def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
                kernels.append(kernel)

        return kernels
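
    # Worked example: for an x2 upscale with half_pixel_centers=True the fractional offsets
    # evaluate to 0.25 and 0.75, so the four 2x2 kernels are the classic bilinear taps
    # [[1, 3], [3, 9]] / 16 and their mirrorings, one per output pixel position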

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, dw_conv)

                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op
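
# Worked example of convert_batched_fc_shape: a FullyConnected IFM of shape [8, 1, 1, 64] is
# rewritten to [1, 2, 4, 64] (the batch of 8 folds into a 2x4 H/W plane); batch sizes not in
# batching_split fall back to [1, 1, n, depth]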


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, act_op)


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
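                # x &= x - 1 clears the lowest set bit, so prev_mask - shrink_axis_mask
                # isolates that bit and log2 gives its position; e.g. a mask of 0b0110
                # yields axis 1 on the first pass and axis 2 on the second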
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def fixup_strided_conv(op: Operation, arch, nng):
    """Optimize or fixup strided Conv2DBias
    Optimization:
        Reduce, when possible, the Conv2DBias stride from 2 to 1 by re-shaping
        both IFM and filter.

    Fixup:
        Introduce software support for Conv2DBias with stride_width = 4 by
        reducing it to 1 when possible by re-shaping both IFM and filter.
    """
    if op.type != Op.Conv2DBias:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]
    if (
        (stride_x == 2 or stride_x == 4)
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, stride_x, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // stride_x, 1, (k_w + 1) // stride_x)
        padding_type = op.attrs.get("padding", None)

        # If padding is enabled, check if current padding matches optimised padding
        if not padding_type or (padding_type != Padding.VALID and curr_padding_x != optimised_padding_x):
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D(
            [ifm_shape.batch, ifm_shape.height, ifm_shape.width // stride_x, ifm_shape.depth * stride_x]
        )
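        # e.g. an IFM of [1, 8, 8, 3] with stride_x == 2 becomes [1, 8, 4, 6]: pairs of
        # neighbouring columns are folded into the channel dimension so that a stride-1
        # convolution reads the same data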

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array

        # Change weight shape based on stride_x
        weight_shape[1] //= stride_x
        weight_shape[2] *= stride_x

        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

        op.ifm.force_linear_format = True
    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own primary op to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()
                    DebugDatabase.add_optimised(op, mul_identity)

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.Add, op.name + "_add")
        add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |   \   /           This subgraph can be replaced with either
    |    Mul            an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scaling", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
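        # HardSwish is defined as x * relu6(x + 3) / 6; the fixed-point loop below
        # evaluates it for every quantized input value to build the LUT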
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
1304 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
1305 relu_value = np.int16(input_value_hires)
1306 if relu_shift < 31:
1307 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
1308
1309 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
1310
1311 if relu_shift < 31:
1312 relu_value = fp_math.shift_left16(relu_value, 1)
1313
1314 if relu_shift > 31:
1315 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
1316
1317 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
1318 # Now convert that to a 16bit fixedpoint value in [0, 1]
1319 relu_value = (relu_value + (1 << 15)) >> 1
1320 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
1321 shift = 31 - out_shift
1322 shift = -shift if shift < 0 else 0
1323 # Finally apply the output shift
1324 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
1325 lut_result = min(quantized_max, max(quantized_min, lut_result))
1326 values.append(lut_result)
1327 return convert_to_lut(op, values, "hardswish")
1328 return op
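

# Float-domain reference for the fixed-point LUT above, assuming the usual definition
# hardswish(x) = x * relu6(x + 3) / 6. Each LUT entry approximates
# quantize(hardswish(dequantize(x))) for one input code x; the helper name below is
# hypothetical.
def _hardswish_reference(x_real: float) -> float:
    # relu6(x + 3) / 6 is the "relu-ish multiplier" computed in fixed point above
    return x_real * min(max(x_real + 3.0, 0.0), 6.0) / 6.0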


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    alpha = np.float32(op.attrs["alpha"])
    use_mul_max = 0 < alpha < 1
    is_converted_prelu = "alpha_scaling" in op.attrs
    if use_mul_max:
        mul_ifm = ifm
        new_op = Op.Maximum
    else:
        # Need to use a different approach for alpha < 0 or alpha > 1
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
        if alpha < 0 and not is_converted_prelu:
            # For negative alpha that is not from a converted PReLU we need to use
            # int32 Mul below to perform the (negative) alpha scaling
            mul_ifm.dtype = DataType.int32
        min_op.set_output_tensor(mul_ifm)
        min_op.set_ifm_ofm_shapes()
        new_op = Op.Add
        op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        DebugDatabase.add_optimised(op, min_op)

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(mul_ifm)
    # Create const tensor containing alpha as scalar
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    alpha_dtype = mul_ifm.dtype
    if is_converted_prelu:
        # The LeakyRelu was the result from convert_prelu and the scaling is provided
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
    elif alpha == 0 or np.isinf(1 / alpha):
        # Handling of alpha near or at zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        if alpha_dtype == DataType.int32:
            # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
            scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
        else:
            scalar = 1
    alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if not use_mul_max:
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_id)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)
    elif check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max (or Add), with the results of the multiplication(s) as inputs
    op.type = new_op
    op.name = op.name.replace("LeakyRelu", new_op.name)
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op
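

# Float-domain sketch (hypothetical helper) of the two decompositions built above:
# for 0 < alpha < 1, LeakyRelu(x) == max(x, alpha * x); for any other alpha it is
# rewritten as Relu(x) + alpha * min(x, 0), which also holds for alpha < 0 and alpha > 1.
def _lrelu_decomposition_reference(x_real: float, alpha: float) -> float:
    if 0 < alpha < 1:
        return max(x_real, alpha * x_real)
    return max(x_real, 0.0) + alpha * min(x_real, 0.0)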


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)
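

# Worked example (hypothetical quantization parameters): for an int8 Sigmoid with
# ifm scale 0.1, ifm zero point 0, ofm scale 1 / 256 and ofm zero point -128, the LUT
# entry for input code x = 10 is
#   x_real = 0.1 * (10 - 0) = 1.0
#   y_real = sigmoid(1.0) ~= 0.731
#   lut_result = round_away_zero(-128 + 0.731 / (1 / 256)) = round_away_zero(59.15) = 59
# which is already inside [-128, 127], so the final clamp does not change it.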


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")
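

# Note: the table above is piecewise; input codes below the ifm zero point are mapped
# with the alpha multiplier/shift and codes at or above it with the identity
# multiplier/shift, so one 256-entry LUT reproduces LeakyRelu exactly for int8/uint8.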


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    alpha = op.attrs["alpha"]
    if alpha == 0:
        # When alpha is 0 the operation can be converted to a ReLU
        op.type = Op.Relu
        op.name = op.name.replace("LeakyRelu", op.type.name)
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
        # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # find the op that produces the input to this op
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(prev_op, prev_op)
    return op
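

# Illustrative effect of the fusion above (op names hypothetical): a Conv2D followed by
# a no-op that only carries a fused ReLU6 becomes a single Conv2D with the ReLU6 (plus
# any forced output quantization and activation LUT) attached:
#   Conv2D -> nop(ReLU6) -> ofm   ==>   Conv2D(ReLU6) -> ofm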


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
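

# Worked example: for kernel_size = 7 and stride = 2, max_size = 3 > stride, so a
# leading pad of 3 (== max_size) or an even pad (0 or 2) is accepted, while a leading
# pad of 1 would make hardware padding iterate the wrong IFM rows/columns.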


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses explicit hardware padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng, DataType.int32)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

    return op
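

# Example of the rewrite above: tens1 -> PAD(1, 1, 1, 1) -> CONV 3x3 VALID becomes
# tens1 -> CONV 3x3 with Padding.EXPLICIT and explicit_padding = (top, left, bottom,
# right) = (1, 1, 1, 1); the padded tensor never materialises in memory because the
# hardware inserts the zeros itself.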


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op
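

# OFM layout produced by convert_pad: the top/bottom fill regions span the full OFM
# width, the left/right regions only the IFM height, and each region is filled with
# the OFM zero point by its own average pool op:
#   +--------- top ---------+
#   | left |   IFM   | right|
#   +------- bottom --------+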


def fixup_bias_tensors(op, arch, nng, dtype=None):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
        # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
        # For int16 the selected bias DataType will have an impact on the scaling
        # used when encoding the scales and biases later. The default mode will match the
        # reference with reduced scaling for int64 bias.
        # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
        # is used to emulate average pool, int32 bias should be selected for full precision
        # int16 scaling.
        if dtype is None:
            dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op
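

# e.g. a missing bias on an op with an int16 ifm defaults to a zero-filled int64 bias
# tensor, while 8-bit ifms get int32; replace_pad_by_hw_pad passes DataType.int32
# explicitly to keep full precision scaling for its emulated average pool.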


def detect_asymmetric_weights(op):
    # Check all ops (cpu and npu)
    if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
        if op.ifm.dtype in (DataType.int8, DataType.int16):
            if not np.all(op.weights.quantization.zero_point == 0):
                print(f"Warning: Op {op.type} '{op.name}' has asymmetric weights.", end=" ")
                return True
    return False


def fixup_asymmetric_weights(op, arch, nng):
    if detect_asymmetric_weights(op):
        if op.run_on_npu:
            print("Zero points have been adjusted.")
            op.weights.quantization.zero_point *= 0
    return op


def check_asymmetric_weights(op, arch, nng):
    # This function can modify the run_on_npu flag which causes an operator to be placed on the CPU. It is usually only
    # set by the supported operator checks. Therefore, it should be run immediately after those checks to avoid the
    # possibility of other graph optimiser functions modifying the operator (which would then run on the CPU)
    if detect_asymmetric_weights(op):
        if op.run_on_npu:
            print("To run the operator on Ethos-U use the option --force-symmetric-int-weights")
            op.run_on_npu = False
    return op


def fixup_or_check_asymmetric_weights(force_symmetric_int_weights):
    if force_symmetric_int_weights:
        return fixup_asymmetric_weights
    else:
        return check_asymmetric_weights
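

# Note: with --force-symmetric-int-weights the zero points are simply zeroed (the
# warning above flags that this changes the arithmetic), otherwise the affected
# operator is placed on the CPU instead.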


def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
    if op.type == Op.Mean and op.run_on_npu:
        inp, axis = op.inputs
        shape = inp.shape
        ofm_shape = op.ofm.shape
        dims = len(shape)
        dims_ofm = len(ofm_shape)

        # Height and width axes have different indices depending on dimensions
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
            if dims in (2, 3):
                if axis == 0:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
            else:
                if axis == 1:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
        else:  # multiple axes
            axis = sorted(axis.values)
            h, w = [shape[i] for i in axis]

        # Set necessary depthwise attributes
        op.attrs.update(
            {
                "padding": Padding.VALID,
                "stride_h": 1,
                "stride_w": 1,
                "strides": (1, 1, 1, 1),
                "depth_multiplier": 1,
                "channel_multiplier": 1,
                "dilation_h_factor": 1,
                "dilation_w_factor": 1,
                "dilation": (1, 1, 1, 1),
            }
        )
        # Change op type
        op.type = Op.DepthwiseConv2DBias
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()

        weight_scale, bias = 1, 0
        ofmq, ifmq = op.ofm.quantization, inp.quantization
        if ifmq.is_scaling_equal(ofmq):
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        def extend_dims(dim, in_shape):
            if dim < 4:
                in_shape = [1] + in_shape
                if dim == 2:
                    in_shape += [1]
            return in_shape

        if dims < 4 or dims_ofm < 4:
            # Fix the ofm dimension when keep_dims is false
            # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
            if isinstance(axis, int) and dims_ofm + 1 == dims:
                ofm_shape.insert(axis, 1)
            elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
                for i in axis:
                    ofm_shape.insert(i, 1)
            shape = extend_dims(dims, shape)
            dims_ofm = len(ofm_shape)
            ofm_shape = extend_dims(dims_ofm, ofm_shape)
            op.set_ifm_ofm_shapes()

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        weight_shape = None
        if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
            # This can only happen and be done for multiple axes, and
            #   h * w <= 256 for DepthwiseConv2DBias
            #   h * w <= 4096 for AvgPool
            # which is checked in supported ops
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            weight_shape = [1, h * w, shape[3], shape[0]]
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            DebugDatabase.add_optimised(op, op)
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        if weight_shape is None:
            # Set weight shape to [H,W,C,B]
            weight_shape = [h, w, shape[3], shape[0]]

        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add bias tensor
        bias_shape = [shape[-1]]
        op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
        DebugDatabase.add_optimised(op, op)

    return op
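

# Worked example (hypothetical shapes): a Mean over axis [1, 2] of an int8 1x4x4x8
# input with identical ifm/ofm scaling becomes a 4x4 AvgPool with TRUNCATE rounding
# (plain integer division). With differing scaling it becomes a DepthwiseConv2DBias
# with unit weights, weight scale 1 / 16 and bias -zero_point * 16 to re-centre the sum.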


def optimise_quantize(op: Operation, arch, nng):

    if op.type == Op.Quantize and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()
        input_values = ifm.values

        # Guard clause - input not const or no values to quantize
        if ifm.ops[0].type != Op.Const or input_values is None:
            return op

        # Singular val in numpy array, convert to indexable array
        if input_values.ndim == 0:
            input_values = np.array([input_values])

        # Requantize int8 to int8 or int16 to int16
        if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:

            # scale needs to use double precision to match TFLite reference kernel
            effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
            effective_multiplier, effective_shift = quantise_scale(effective_scale)

            requantized_vals = []
            for val in input_values.flatten():
                input_val = val - ifm.quantization.zero_point

                ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
                ofm_val += ofm.quantization.zero_point

                clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
                requantized_vals.append(clamped_ofm_value)

            ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
            ofm.values.shape = input_values.shape

        # Case: Float input - quantize to int
        elif ifm.dtype.type == BaseType.Float:

            quantized_vals = []
            for val in input_values:

                # Derive quantized value
                quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
                clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
                quantized_vals.append(clamped_quantized_val)

            # Pass the statically calculated quant val to output tensor
            ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())

        # Unsupported data type
        else:
            return op

        # Make quantize op const and disconnect from parent node

        # Remove reference of the current quant op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this quantize op to const
        op.type = Op.Const

    return op
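

# Worked example: a constant int8 Quantize with ifm scale 0.5, ofm scale 0.25 and both
# zero points 0 has effective_scale = 2.0, so the stored code 3 (real value 1.5) is
# requantized at compile time to code 6, and the op itself is turned into a Const.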


def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
    """Static optimisation for SHAPE operator output value known at compile time"""

    # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant

    if op.type == Op.Shape and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()

        if len(ifm.shape) != ofm.shape[0]:
            return op

        # Remove reference of the current shape op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this SHAPE op to const
        op.type = Op.Const

        # Add size calculation to shape output tensors
        ofm.values = np.array(ifm.shape)

    return op
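

# e.g. a SHAPE op whose input has shape [1, 8, 8, 3] is replaced by a Const producing
# np.array([1, 8, 8, 3]); the guard above skips cases where the number of output
# elements does not match the input rank.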


def fixup_dilation_gt2(op, arch, nng):
    assert op.run_on_npu
    if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
        dilation_w, dilation_h = op.get_kernel_dilation()

        # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
        # kernel
        if dilation_w > 2 or dilation_h > 2:
            kernel_w, kernel_h = op.get_kernel_size()
            kernel_ic = op.weights.shape[-2]
            kernel_oc = op.weights.shape[-1]

            # if the dilation is a multiple of 2 then hardware dilation can be enabled to provide that multiple
            # of 2. This allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
            # odd = 1, even = 2
            hw_dilation_h = 1 if (dilation_h & 1) else 2
            hw_dilation_w = 1 if (dilation_w & 1) else 2

            scale_dilation_h = dilation_h // hw_dilation_h
            scale_dilation_w = dilation_w // hw_dilation_w

            # create new empty kernel (HWIO format)
            new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
            new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1

            new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
            new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)

            # copy the original kernel values into the new sparse kernel
            for h in range(0, kernel_h):
                for w in range(0, kernel_w):
                    new_h = h * scale_dilation_h
                    new_w = w * scale_dilation_w
                    new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]

            # update the weight tensor with the new dilated kernel
            op.weights.shape = new_kernel_shape
            op.weights.values = new_kernel_values

            # enable(=2) / disable(=1) hardware dilation
            op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1)  # nhwc format
            op.attrs["dilation_h_factor"] = hw_dilation_h
            op.attrs["dilation_w_factor"] = hw_dilation_w

    return op
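

# Worked example: dilation 4 splits into hardware dilation 2 and kernel scaling 2, so a
# 3x3 kernel becomes a sparse 5x5 kernel ((3 - 1) * 2 + 1) with the original taps at
# every second position. Dilation 3 (odd) disables hardware dilation and builds a
# sparse 7x7 kernel instead.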


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
    # Compile time static optimisations
    optimisation_list = [
        optimise_quantize,
        convert_shape_op_to_constant_tensor,
        fixup_or_check_asymmetric_weights(force_symmetric_int_weights),
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            optimisation_list,
            rewrite_unsupported=False,
        )

    # Pre-processing step
    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            pre_process_list,
            rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [rewrite_split_ops],
            [],
            rewrite_unsupported=False,
        )

    # Bypass or rewrite memory only operators
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [bypass_memory_only_ops],
            rewrite_unsupported=False,
        )

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        convert_prelu,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        convert_argmax_to_depthwise_conv_and_max_pool,
        fixup_resize,
        fixup_bias_tensors,
        fixup_asymmetric_weights,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
        fixup_dilation_gt2,
        fixup_strided_conv,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            op_rewrite_list,
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    # Make sure that const optimisations on subgraph outputs are handled correctly
    for sg in nng.subgraphs:
        for ofm in sg.output_tensors:
            if ofm.is_const and ofm.ops[0].type_changed:
                # Subgraph output cannot be const - insert a memory copy
                op = ofm.ops[0]
                ofm_clone = ofm.clone()
                ofm_clone.values = ofm.values
                ofm.values = None
                zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
                memcpy = create_add_nop(f"{ofm.name}_copy")
                memcpy.add_input_tensor(ofm_clone)
                memcpy.add_input_tensor(zero)
                memcpy.set_output_tensor(ofm)
                memcpy.set_ifm_ofm_shapes()
                op.set_output_tensor(ofm_clone)
                DebugDatabase.add_optimised(op, memcpy)

    return nng