# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset
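    # Illustrative example (values assumed, not from the original code): concatenating two
    # [1, 4, 4, 8] inputs along the depth axis (axis == 3, so axis_4D == 3) creates one
    # AvgPool per input, with write_offset [0, 0, 0, 0] and [0, 0, 0, 8] respectively; the
    # assert above then checks that the accumulated offset equals the OFM depth of 16.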

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)
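        # Illustrative example (values assumed): splitting a [1, 4, 4, 16] tensor into two
        # outputs along depth gives each output a SplitSliceRead with read_shape [1, 4, 4, 8];
        # the first reads from offset [0, 0, 0, 0] and the second from [0, 0, 0, 8], the offset
        # being accumulated from the ofm_shapes of the outputs that precede it.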

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
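    # Worked example (assumed values, and assuming needed_total_padding returns
    # max(0, (ceil(in / stride) - 1) * stride + kernel - in)): SAME padding with a 3x3 kernel,
    # stride 2 and a 224x224 IFM gives xpad = ypad = 1, split asymmetrically as
    # padding = (0, 0, 1, 1), i.e. the extra pixel of padding goes to the bottom/right.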
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
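    # Worked example (assumed values, same needed_total_padding assumption as above): SAME
    # padding with a 2x2 kernel, unit strides, a 16x16 IFM and upscaling_factor 2 gives
    # xpad = ypad = 1 on the 32x32 upscaled IFM, so right_pad = bottom_pad = 0 and
    # left_pad = top_pad = 1, i.e. padding = (1, 1, 0, 0).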
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
        DebugDatabase.add_optimised(op, op)

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    name = op.inputs[1].name + "_add"
    dtype = op.inputs[0].dtype
    shape = op.ofm_shapes[0].as_list()
    values = np.zeros(shape, dtype.as_numpy_type())
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = 1.0
    quantization.zero_point = 0
    op.inputs[1] = op.inputs[0]
    op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 0)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype.type == BaseType.UnsignedInt:
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
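    # Illustrative example (values assumed): for upscale_factor == 2 the kernel is 2x2 and
    # centre_coeff == (2 // 2) * 2 + (2 // 2) == 3, i.e. the bottom-right value 'D' in the
    # diagram above; for upscale_factor == 4 it is (4 // 2) * 4 + (4 // 2) == 10.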

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm_dtype,
            np.array(weight_values).reshape(weight_shape),
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))
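    # Illustrative example (values assumed): upscale_factor == 8 gives n == 3; the loop below
    # then performs n - 1 == 2 nearest neighbor x2 upscales, and the final x2 is handled by
    # the "last x2 upscaling" average pool further down.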

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

        if scaled_op.original_type == Op.ResizeBilinear:
            if scaled_op.attrs["align_corners"]:
                # no padding
                scaled_op.attrs["padding"] = Padding.VALID
            else:
                # padding to the right and bottom (limits average pool to 8x8 kernel)
                scaled_op.attrs["padding"] = Padding.EXPLICIT
                scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

            # kernel size dependent on the upscaling factor
            scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
        else:  # Op.ResizeNearestNeighbor
            if scaled_op.attrs["align_corners"]:
                # use depthwise conv to select the correct value
                scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
            else:
                # Keep 1x1 kernel and average pool, this applies both when
                # half-pixel-centers is True and False. Calculations are the
                # same in the reference.
                pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, scaled_op)

    return op


def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
                kernels.append(kernel)

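        # Illustrative example (values assumed): for a x2 upscale with half_pixel_centers True,
        # _compute_interpolation_values gives fractional parts 0.25 and 0.75 for indices 1 and 2,
        # so the first kernel is [[1, 3], [3, 9]] after the *16; the weight quantization scale of
        # 1/16 set in _build_convolutions undoes that factor of 16.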
        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, dw_conv)

                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
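            # Illustrative example (values assumed): a batch of 8 maps to h, w == (2, 4), so an
            # IFM with batch 8 and depth d is reshaped from [8, 1, 1, d] to [1, 2, 4, d];
            # batch sizes not present in batching_split fall back to (1, n).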
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, act_op)


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def fixup_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    # Do not optimize if op is not the first in the network and stride is
    # supported by the hardware
    if op.op_index != 0 and stride_x < 4:
        return op
    op.ifm.needs_linear_format = True

    if (
        (stride_x == 2 or stride_x == 4)
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, stride_x, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // stride_x, 1, (k_w + 1) // stride_x)
        padding_type = op.attrs.get("padding", None)

        # If padding is enabled, check if current padding matches optimised padding
        if not padding_type or (padding_type != Padding.VALID and curr_padding_x != optimised_padding_x):
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D(
            [ifm_shape.batch, ifm_shape.height, ifm_shape.width // stride_x, ifm_shape.depth * stride_x]
        )

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array

        # Change weight shape based on stride_x
        weight_shape[1] //= stride_x
        weight_shape[2] *= stride_x

        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
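        # Illustrative example (values assumed): with stride_x == 2, an IFM of [1, 8, 8, 3]
        # becomes [1, 8, 4, 6] (width halved, depth doubled), and a 3x3x3xO HWIO weight tensor
        # is first padded with zero-point values to kw == 4 and then reshaped to 3x2x6xO, after
        # which a stride of 1 gives the same result as the original strided convolution.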

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If alpha is constant, check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()
                    DebugDatabase.add_optimised(op, mul_identity)

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.Add, op.name + "_add")
        add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X    For X = -1 or X > 0
    |   \   /     This subgraph can be replaced with either
    |    Mul      an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


1173def convert_lrelu_to_mul_max(op, arch):
1174 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
1175 # (the opposite of convert_mul_max_to_abs_or_lrelu)
1176 ifm, ofm = op.get_ifm_ofm()
1177 if ifm is None or ofm is None:
1178 return op
1179
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001180 alpha = np.float32(op.attrs["alpha"])
1181 use_mul_max = 0 < alpha < 1
Fredrik Svedberg36424312022-09-16 09:39:26 +02001182 is_converted_prelu = "alpha_scaling" in op.attrs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001183 if use_mul_max:
1184 mul_ifm = ifm
1185 new_op = Op.Maximum
1186 else:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001187 # Need to use a different approach for alpha < 0 or alpha > 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001188 no_scale_quant = ifm.quantization.clone()
1189 no_scale_quant.scale_f32 = None
1190 no_scale_quant.zero_point = 0
1191 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
1192
1193 # Select values < 0
1194 min_op = Operation(Op.Minimum, op.name + "_min")
1195 min_op.add_input_tensor(ifm)
1196 min_op.add_input_tensor(zero)
1197 mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001198 if alpha < 0 and not is_converted_prelu:
1199 # For negative alpha that is not from a converted PReLU we need to use
1200 # int32 Mul below to perform the (negative) alpha scaling
1201 mul_ifm.dtype = DataType.int32
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001202 min_op.set_output_tensor(mul_ifm)
1203 min_op.set_ifm_ofm_shapes()
        new_op = Op.Add
        op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        DebugDatabase.add_optimised(op, min_op)

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(mul_ifm)
    # Create const tensor containing alpha as scalar
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    alpha_dtype = mul_ifm.dtype
    if is_converted_prelu:
        # The LeakyRelu was the result of convert_prelu and the scaling is provided
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
    elif alpha == 0 or np.isinf(1 / alpha):
        # Handling of alpha near or at zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        if alpha_dtype == DataType.int32:
            # When the datatype is int32 (alpha is negative) the scaling has to be done with the multiplication
            scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
        else:
            scalar = 1
    alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if not use_mul_max:
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_id)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)
    elif check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max (or Add), using the results of the multiplication(s) as inputs
    op.type = new_op
    op.name = op.name.replace("LeakyRelu", new_op.name)
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op
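
# Editor's sketch (illustrative values, not part of the pass): for 0 < alpha < 1,
# convert_lrelu_to_mul_max above computes LeakyRelu as an elementwise maximum, e.g.
# with alpha = 0.1:
#   fm_alpha = 0.1 * x   (Mul with a scalar const, alpha folded into the quantization)
#   fm_id    = 1.0 * x   (identity Mul, or the IFM itself when scaling already matches)
#   ofm      = Max(fm_alpha, fm_id)
# which is x for x >= 0 and 0.1 * x for x < 0. For alpha < 0 or alpha > 1 the maximum
# would select the wrong operand, so the Min/Mul/Relu/Add path above is used instead.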


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)
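
# Editor's sketch (assumed example values, not part of the pass): one LUT entry as
# generated by convert_to_lut8 above, for int8 with ifm_scale = 0.5, ofm_scale = 0.25,
# zp_in = zp_out = 0 and fn = math.tanh:
#   x = 4  ->  x_real = 0.5 * (4 - 0) = 2.0
#          ->  y_real = tanh(2.0) ~ 0.964
#          ->  lut_result = round_away_zero(0 + 0.964 / 0.25) = 4, clamped to [-128, 127]
# Repeating this for all 256 input codes gives the table passed to convert_to_lut.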


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result of convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")
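
# Editor's sketch (assumed example values, not part of the pass): convert_lrelu_to_lut
# above applies two quantized multipliers around the input zero point. For int8 with
# alpha = 0.1 and equal ifm/ofm scaling:
#   x >= zp_in:  lut[x] ~ zp_out + 1.0 * (x - zp_in)    (identity slope)
#   x <  zp_in:  lut[x] ~ zp_out + 0.1 * (x - zp_in)    (alpha slope)
# so the whole activation collapses into a single 256-entry table lookup.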


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    alpha = op.attrs["alpha"]
    if alpha == 0:
        # When alpha is 0 the operation can be converted to a ReLU
        op.type = Op.Relu
        op.name = op.name.replace("LeakyRelu", op.type.name)
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
        # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
        return op
    return convert_lrelu_to_mul_max(op, arch)
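
# Editor's note (summary of the dispatch in convert_lrelu above): the lowering picks,
# in order: Relu when alpha == 0; a 256-entry LUT for int8/uint8 with matching ifm/ofm
# dtypes; the unmodified op for int16 with equal scaling and positive alpha; otherwise
# the Mul/Max (or Min/Mul/Relu/Add) decomposition in convert_lrelu_to_mul_max.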


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(prev_op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
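
# Editor's sketch (assumed example values, not part of the pass): _leading_pad_ok with
# kernel_size = 7 and stride = 2 gives max_size = 3, so a leading pad of 3 (== max_size)
# or of 0/2 (multiples of the stride) is accepted, whereas a leading pad of 1 would make
# hardware padding start the kernel window on the wrong IFM row/column and is rejected.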


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng, DataType.int32)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

    return op
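
# Editor's sketch (illustrative, not part of the pass): the typical case handled by
# replace_pad_by_hw_pad above is
#   tens1 -> PAD(1, 1, 1, 1) -> tens2 -> CONV2D(3x3, VALID)
# which becomes tens1 -> CONV2D(3x3, EXPLICIT padding (1, 1, 1, 1)), removing one full
# feature-map round trip through memory. Padding larger than kernel_size // 2 cannot be
# expressed as hardware padding and is left for convert_pad below.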


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as a fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op
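
# Editor's sketch (illustrative, not part of the pass): after convert_pad above, the OFM
# is assembled by up to five average pools writing disjoint regions (shown for a padding
# of 1 on every side):
#   +------------------+
#   |       top        |
#   +----+--------+----+
#   |left|  IFM   |right|
#   +----+--------+----+
#   |      bottom      |
#   +------------------+
# The border writers read small const tensors filled with the quantized zero point, which
# matches PAD's semantics of padding with zero.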


def fixup_bias_tensors(op, arch, nng, dtype=None):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
        # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
        # For int16 the selected bias DataType will have an impact on the scaling
        # used when encoding the scales and biases later. The default mode will match the
        # reference with reduced scaling for int64 bias.
        # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
        # is used to emulate average pool, int32 bias should be selected for full precision
        # int16 scaling.
        if dtype is None:
            dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op


def detect_asymmetric_weights(op):
    # Check all ops (cpu and npu)
    if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
        if op.ifm.dtype in (DataType.int8, DataType.int16):
            if not np.all(op.weights.quantization.zero_point == 0):
                print(f"Warning: Op {op.type} '{op.name}' has asymmetric weights.", end=" ")
                return True
    return False


def fixup_asymmetric_weights(op, arch, nng):
    if detect_asymmetric_weights(op):
        if op.run_on_npu:
            print("Zero points have been adjusted.")
            op.weights.quantization.zero_point *= 0
    return op


def check_asymmetric_weights(op, arch, nng):
    # This function can modify the run_on_npu flag, which causes an operator to be placed on the CPU. That flag is
    # usually only set by the supported operator checks. Therefore, this function should be run immediately after
    # those checks to avoid the possibility of other graph optimiser functions modifying an operator that is later
    # run on the CPU.
    if detect_asymmetric_weights(op):
        if op.run_on_npu:
            print("To run the operator on Ethos-U use the option --force-symmetric-int-weights")
            op.run_on_npu = False
    return op


def fixup_or_check_asymmetric_weights(force_symmetric_int_weights):
    if force_symmetric_int_weights:
        return fixup_asymmetric_weights
    else:
        return check_asymmetric_weights


def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
    if op.type == Op.Mean and op.run_on_npu:
        inp, axis = op.inputs
        shape = inp.shape
        ofm_shape = op.ofm.shape
        dims = len(shape)
        dims_ofm = len(ofm_shape)

        # Height and width axes have different index depending on dimensions
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
            if dims in (2, 3):
                if axis == 0:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
            else:
                if axis == 1:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
        else:  # multiple axes
            axis = sorted(axis.values)
            h, w = [shape[i] for i in axis]

        # Set necessary depthwise attributes
        op.attrs.update(
            {
                "padding": Padding.VALID,
                "stride_h": 1,
                "stride_w": 1,
                "strides": (1, 1, 1, 1),
                "depth_multiplier": 1,
                "channel_multiplier": 1,
                "dilation_h_factor": 1,
                "dilation_w_factor": 1,
                "dilation": (1, 1, 1, 1),
            }
        )
        # Change op type
        op.type = Op.DepthwiseConv2DBias
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()

        weight_scale, bias = 1, 0
        ofmq, ifmq = op.ofm.quantization, inp.quantization
        if ifmq.is_scaling_equal(ofmq):
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        def extend_dims(dim, in_shape):
            if dim < 4:
                in_shape = [1] + in_shape
                if dim == 2:
                    in_shape += [1]
            return in_shape

        if dims < 4 or dims_ofm < 4:
            # Fix the ofm dimension when keep_dims is false
            # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
            if isinstance(axis, int) and dims_ofm + 1 == dims:
                ofm_shape.insert(axis, 1)
            elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
                for i in axis:
                    ofm_shape.insert(i, 1)
            shape = extend_dims(dims, shape)
            dims_ofm = len(ofm_shape)
            ofm_shape = extend_dims(dims_ofm, ofm_shape)
            op.set_ifm_ofm_shapes()

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        weight_shape = None
        if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
            # This can only happen and be done for multiple axes, and
            # h * w <= 256 for DepthwiseConv2DBias
            # h * w <= 4096 for AvgPool
            # which is checked in supported ops
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            weight_shape = [1, h * w, shape[3], shape[0]]
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            DebugDatabase.add_optimised(op, op)
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        if weight_shape is None:
            # Set weight shape to [H,W,C,B]
            weight_shape = [h, w, shape[3], shape[0]]

        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add bias tensor
        bias_shape = [shape[-1]]
        op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
        DebugDatabase.add_optimised(op, op)

    return op
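
# Editor's sketch (assumed example values, not part of the pass): a MEAN over H and W of
# an int8 1x4x4xC tensor with input zero point zp, where ifm and ofm scaling differ,
# becomes a 4x4 depthwise convolution in which every weight is 1, the weight scale
# 1 / 16 carries the division, and bias = -zp * 16 removes the input offset before the
# output quantization is applied. With equal ifm/ofm scaling the same mean is instead a
# plain 4x4 AvgPool with truncating rounding.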


def optimise_quantize(op: Operation, arch, nng):

    if op.type == Op.Quantize and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()
        input_values = ifm.values

        # Guard clause - input not const or no values to quantize
        if ifm.ops[0].type != Op.Const or input_values is None:
            return op

        # Singular val in numpy array, convert to indexable array
        if input_values.ndim == 0:
            input_values = np.array([input_values])

        # requantized int8 to int8 or int16 to int16
        if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:

            # scale needs to use double precision to match TFLite reference kernel
            effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
            effective_multiplier, effective_shift = quantise_scale(effective_scale)

            requantized_vals = []
            for val in input_values.flatten():
                input_val = val - ifm.quantization.zero_point

                ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
                ofm_val += ofm.quantization.zero_point

                clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
                requantized_vals.append(clamped_ofm_value)

            ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
            ofm.values.shape = input_values.shape

        # Case: Float input - quantize to int
        elif ifm.dtype.type == BaseType.Float:

            quantized_vals = []
            for val in input_values:

                # Derive quantized value
                quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
                clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
                quantized_vals.append(clamped_quantized_val)

            # Pass the statically calculated quant val to output tensor
            ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())

        # Unsupported data type
        else:
            return op

        # Make quantize op const and disconnect from parent node

        # Remove reference of the current quant op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this quantize op to const
        op.type = Op.Const

    return op
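
# Editor's sketch (assumed example values, not part of the pass): constant folding of a
# requantize by optimise_quantize above, for int8 with ifm scale 0.5, ofm scale 0.25 and
# both zero points 0:
#   effective_scale = 0.5 / 0.25 = 2.0  ->  quantise_scale yields a multiplier/shift pair
#   val = 7  ->  multiply_by_quantized_multiplier(7, multiplier, shift) = 14, then clamped
# after which the QUANTIZE op is replaced by a Const tensor at compile time.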


def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
    """Static optimisation for SHAPE operator output value known at compile time"""

    # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant

    if op.type == Op.Shape and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()

        if len(ifm.shape) != ofm.shape[0]:
            return op

        # Remove reference of the current shape op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this SHAPE op to const
        op.type = Op.Const

        # Add size calculation to shape output tensors
        ofm.values = np.array(ifm.shape)

    return op


def fixup_dilation_gt2(op, arch, nng):
    assert op.run_on_npu
    if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
        dilation_w, dilation_h = op.get_kernel_dilation()

        # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
        # kernel
        if dilation_w > 2 or dilation_h > 2:
            kernel_w, kernel_h = op.get_kernel_size()
            kernel_ic = op.weights.shape[-2]
            kernel_oc = op.weights.shape[-1]

            # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
            # of 2. This allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
            # odd = 1, even = 2
            hw_dilation_h = 1 if (dilation_h & 1) else 2
            hw_dilation_w = 1 if (dilation_w & 1) else 2

            scale_dilation_h = dilation_h // hw_dilation_h
            scale_dilation_w = dilation_w // hw_dilation_w

            # create new empty kernel (HWIO format)
            new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
            new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1

            new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
            new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)

            # copy the original kernel values into the new sparse kernel
            for h in range(0, kernel_h):
                for w in range(0, kernel_w):
                    new_h = h * scale_dilation_h
                    new_w = w * scale_dilation_w
                    new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]

            # update the weight tensor with the new dilated kernel
            op.weights.shape = new_kernel_shape
            op.weights.values = new_kernel_values

            # enable (=2) / disable (=1) hardware dilation
            op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1)  # nhwc format
            op.attrs["dilation_h_factor"] = hw_dilation_h
            op.attrs["dilation_w_factor"] = hw_dilation_w

    return op
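
# Editor's sketch (assumed example values, not part of the pass): fixup_dilation_gt2
# above with a 3x3 kernel and dilation 4 in both axes splits the factor as
# 4 = 2 (hardware) * 2 (manual): the kernel is zero-stuffed to 5x5 ((3 - 1) * 2 + 1) and
# dilation_h_factor/dilation_w_factor stay at 2, so the effective receptive field is
# still (3 - 1) * 4 + 1 = 9 in each axis.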


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
    # Compile time static optimisations
    optimisation_list = [
        optimise_quantize,
        convert_shape_op_to_constant_tensor,
        fixup_or_check_asymmetric_weights(force_symmetric_int_weights),
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            optimisation_list,
            rewrite_unsupported=False,
        )

    # Pre-processing step
    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            pre_process_list,
            rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [rewrite_split_ops],
            [],
            rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [fix_sg_input_output],
            rewrite_unsupported=False,
        )

    # Removal of memory only operators
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        convert_prelu,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        fixup_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        fixup_resize,
        fixup_bias_tensors,
        fixup_asymmetric_weights,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
        fixup_dilation_gt2,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            op_rewrite_list,
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    # Make sure that const optimisations on subgraph outputs are handled correctly
    for sg in nng.subgraphs:
        for ofm in sg.output_tensors:
            if ofm.is_const and ofm.ops[0].type_changed:
                # Subgraph output cannot be const - insert a memory copy
                op = ofm.ops[0]
                ofm_clone = ofm.clone()
                ofm_clone.values = ofm.values
                ofm.values = None
                zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
                memcpy = create_add_nop(f"{ofm.name}_copy")
                memcpy.add_input_tensor(ofm_clone)
                memcpy.add_input_tensor(zero)
                memcpy.set_output_tensor(ofm)
                memcpy.set_ifm_ofm_shapes()
                op.set_output_tensor(ofm_clone)
                DebugDatabase.add_optimised(op, memcpy)

    return nng