# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
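        # Worked example (illustrative): concatenating two [1, 8, 8, 16] inputs
        # along the depth axis gives axis_4D = 3, so the first avgpool writes at
        # depth offset 0 and the second at depth offset 16 of the shared OFM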
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
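            # Worked example (illustrative): splitting [1, 8, 8, 16] in two along
            # depth gives the outputs read offsets [0, 0, 0, 0] and [0, 0, 0, 8],
            # each with a read shape of [1, 8, 8, 8]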
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
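        # e.g. (illustrative): width 224, stride 2, kernel 3 gives xpad = 1 and
        # left_pad = 0, right_pad = 1; like TFLite SAME padding, the odd padding
        # element goes to the right/bottom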
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
        DebugDatabase.add_optimised(op, op)

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype.type == BaseType.UnsignedInt:
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
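    # e.g. (illustrative): upscale_factor = 2 gives centre_coeff = 3, i.e. the
    # single 1 is placed at the bottom-right position 'D' of the 2x2 kernel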

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm_dtype,
            np.array(weight_values).reshape(weight_shape),
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))
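    # e.g. (illustrative): upscale_factor = 8 gives n = 3, i.e. two intermediate
    # x2 nearest neighbor upscales in the loop below plus a final x2 stage that
    # also performs the average pool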

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

        if scaled_op.original_type == Op.ResizeBilinear:
            if scaled_op.attrs["align_corners"]:
                # no padding
                scaled_op.attrs["padding"] = Padding.VALID
            else:
                # padding to the right and bottom (limits average pool to 8x8 kernel)
                scaled_op.attrs["padding"] = Padding.EXPLICIT
                scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

            # kernel size depends on the upscaling factor
            scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
        else:  # Op.ResizeNearestNeighbor
            if scaled_op.attrs["align_corners"]:
                # use depthwise conv to select the correct value
                scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
            else:
                # Keep 1x1 kernel and average pool, this applies both when
                # half-pixel-centers is True and False. Calculations are the
                # same in the reference.
                pass

        scaled_op.outputs = outputs
        scaled_op.outputs[0].ops = [scaled_op]
        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    return op


def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
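                # Worked example (illustrative): for x2 upscaling with half-pixel
                # centres, scale = 0.5 and (y, x) = (1, 1) give sv - lb = 0.25 on
                # both axes, so this kernel becomes [[1, 3], [3, 9]] (sums to 16,
                # undone by the 1/16 weight scale below)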
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in the H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, dw_conv)

                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape
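        # e.g. (illustrative): a [2, 3, 4] IFM with [12, 8] weights has 24
        # elements and an inner dimension of 12, so it is reshaped to the 2D
        # batch shape Shape4D([2, 1, 1, 12])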

    if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
        # If IFM is batching then also make sure OFM is batching
        h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
        op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
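            # e.g. (illustrative): a batch of 8 is mapped to a 2x4 "spatial"
            # layout, so an IFM of [8, 1, 1, 32] is rewritten to [1, 2, 4, 32];
            # batches without an entry become [1, 1, n, depth]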
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, act_op)


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
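            # Each loop iteration below clears the lowest set bit of the mask;
            # e.g. (illustrative) shrink_axis_mask = 0b100 gives axis = 2, so a
            # size-1 dimension is re-inserted at index 2 of the output shape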
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def fixup_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    # Do not optimize if op is not the first in the network and stride is
    # supported by the hardware
    if op.op_index != 0 and stride_x < 4:
        return op
    op.ifm.needs_linear_format = True

    if (
        (stride_x == 2 or stride_x == 4)
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, stride_x, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // stride_x, 1, (k_w + 1) // stride_x)
        padding_type = op.attrs.get("padding", None)

        # If padding is enabled, check if current padding matches optimised padding
        if not padding_type or (padding_type != Padding.VALID and curr_padding_x != optimised_padding_x):
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D(
            [ifm_shape.batch, ifm_shape.height, ifm_shape.width // stride_x, ifm_shape.depth * stride_x]
        )

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array

        # Change weight shape based on stride_x
        weight_shape[1] //= stride_x
        weight_shape[2] *= stride_x
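        # e.g. (illustrative): a 3x3x3x8 HWIO kernel with stride_x = 2 is first
        # zero-padded to width 4 above and then reshaped to 3x2x6x8, matching
        # the IFM whose width was halved and depth doubled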

        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
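                # (for any alpha < 1, max(alpha * x, x) equals x when x >= 0 and
                # alpha * x when x < 0, which is exactly PReLU)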
947 # Multiply with alpha tensor
948 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
949 mul_alpha.add_input_tensor(ifm)
950 mul_alpha.add_input_tensor(alpha)
951 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
952 mul_alpha.set_output_tensor(fm_alpha)
953 mul_alpha.set_ifm_ofm_shapes()
954 DebugDatabase.add_optimised(op, mul_alpha)
955 if check_quantized_tens_scaling_equal(ifm, ofm):
956 # No scaling is needed
957 fm_id = ifm
958 else:
959 # Add multiplication with identity
960 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
961 mul_identity.add_input_tensor(ifm)
962 # Create const tensor containing identity as scalar
963 quantization = ifm.quantization.clone()
964 quantization.scale_f32 = np.float32(1)
965 quantization.zero_point = 0
966 one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
967 mul_identity.add_input_tensor(one)
968 # Make sure that fm_id is allocated to a different address than fm_alpha
969 fm_id = ofm.clone(op.name + "_id", set_unique=True)
970 mul_identity.set_output_tensor(fm_id)
971 mul_identity.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +0000972 DebugDatabase.add_optimised(op, mul_identity)
Fredrik Svedberg66591652022-08-29 10:51:27 +0200973
974 # Combine scaled and alpha multiplied values
975 max_op = Operation(Op.Maximum, op.name + "_max")
976 max_op.add_input_tensor(fm_alpha)
977 max_op.add_input_tensor(fm_id)
978 max_op.set_output_tensor(ofm)
979 max_op.set_ifm_ofm_shapes()
980
981 DebugDatabase.add_optimised(op, max_op)
982 ifm.consumer_list.remove(op)
983 return max_op
984
985 # Catch all PReLU conversion for the cases that could not be optimised above
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +0200986 no_scale_quant = ifm.quantization.clone()
987 no_scale_quant.scale_f32 = None
988 no_scale_quant.zero_point = 0
Fredrik Svedberg66591652022-08-29 10:51:27 +0200989 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +0200990
991 # Select values < 0
992 min_op = Operation(Op.Minimum, op.name + "_min")
993 min_op.add_input_tensor(ifm)
994 min_op.add_input_tensor(zero)
995 fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
996 min_op.set_output_tensor(fm_negative)
997 min_op.set_ifm_ofm_shapes()
998 DebugDatabase.add_optimised(op, min_op)
999
1000 # and multiply with alpha tensor
1001 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
1002 mul_alpha.add_input_tensor(fm_negative)
1003 mul_alpha.add_input_tensor(alpha)
1004 fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
1005 mul_alpha.set_output_tensor(fm_alpha)
1006 mul_alpha.set_ifm_ofm_shapes()
1007 DebugDatabase.add_optimised(op, mul_alpha)
1008
1009 # Select (and scale) values > 0
1010 relu_op = Operation(Op.Relu, op.name + "_relu")
1011 relu_op.add_input_tensor(ifm)
1012 fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1013 relu_op.set_output_tensor(fm_scaled)
1014 relu_op.set_ifm_ofm_shapes()
1015 DebugDatabase.add_optimised(op, relu_op)
1016
1017 # Add scaled and alpha multiplied values (without scaling)
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001018 add_op = Operation(Op.Add, op.name + "_add")
1019 add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001020 add_op.add_input_tensor(fm_alpha)
1021 add_op.add_input_tensor(fm_scaled)
1022 add_op.set_output_tensor(ofm)
1023 add_op.set_ifm_ofm_shapes()
1024
1025 DebugDatabase.add_optimised(op, add_op)
1026 ifm.consumer_list.remove(op)
1027 op = add_op
1028
1029 return op
1030
1031
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001032def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
1033 r"""Whenever there is a subgraph with this topology:
1034
Jonas Ohlssond8575072022-03-30 10:30:25 +02001035 Input X For X = -1 or X > 0
1036 | \ / This subgraph can be replaced with either
1037 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
1038 | /
1039 Max
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001040 """
1041
1042 if op.type == Op.Maximum:
1043 # finds the Mul input(s) to the Max
1044 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
1045 if len(muls) == 1:
1046 mul = muls[0].ops[0]
1047 elif len(muls) == 2:
1048 # In the case both inputs are Muls, find the one with the same input as the Max
Fredrik Svedberg66591652022-08-29 10:51:27 +02001049 mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
1050 if len(mul_ifms):
1051 mul = mul_ifms[0].ops[0]
1052 else:
1053 # Not using same input
1054 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001055 else:
1056 # No Mul inputs
1057 return op
1058
1059 # make sure the Mul doesn't have any other consumers
1060 mul_ofm = mul.outputs[0]
1061 if len(mul_ofm.consumers()) != 1:
1062 return op
1063 # make sure the Mul doesn't have a fused activation function
1064 if mul.activation:
1065 return op
1066 ifm, ofm = op.get_ifm_ofm()
1067 if ifm is None or ofm is None:
1068 return op
1069
1070 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1071 return op
1072 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
1073 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
1074 return op
1075
1076 # finds the branched input that goes to both the Max and the Mul
1077 shared = set(op.inputs) & set(mul.inputs)
1078 if len(shared) == 1:
1079 shared_in = shared.pop()
1080 # find the constant scalar input to the Mul
1081 const_tens = (set(mul.inputs) - {shared_in}).pop()
1082 # check that it is a scalar
1083 if const_tens.shape != []:
1084 return op
1085 const = const_tens.ops[0]
1086 # check that it is a constant
1087 if const.type != Op.Const:
1088 return op
1089 # Remove the Mul from the shared input's consumers
1090 shared_in.consumer_list.remove(mul)
1091 else:
1092 return op
1093
1094 val = const.outputs[0].values
1095 if val >= 0:
1096 new_op = Op.LeakyRelu
1097 op.attrs["alpha"] = val
1098 # to produce bit exact results, the alpha is not enough;
1099 # save additional scaling info in attr "alpha_scale", to be used as input
1100 # to the LUT construction
James Peet7519d502021-07-19 16:47:58 +01001101 alpha_scalar = const_tens.values - const_tens.quantization.zero_point
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001102 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
1103 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
1104 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
1105 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
1106 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
1107 elif val == -1:
1108 new_op = Op.Abs
1109 else:
1110 return op
1111
1112 op.type = new_op
1113 op.name = op.name.replace("Maximum", new_op.name)
1114 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
1115 op.inputs = [shared_in]
1116 op.set_ifm_ofm_shapes()
1117
1118 # Record optimisation in debug database
1119 DebugDatabase.add_optimised(op, op)
1120
1121 return op
1122
1123
1124def convert_hardswish_to_lut(op, arch, nng):
1125 if op.type == Op.HardSwish:
1126 ifm, ofm = op.get_ifm_ofm()
1127 # Generate the LUT
1128 ifm_scale = np.double(ifm.quantization.scale_f32)
1129 ofm_scale = np.double(ofm.quantization.scale_f32)
1130 zp_in = ifm.quantization.zero_point
1131 zp_out = ofm.quantization.zero_point
1132 ifm_scale_hires = (1 / 128) * ifm_scale
1133 relu_multiplier = np.double(3 / 32768)
1134 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
1135 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
1136 # Use 16bit scale
1137 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
1138 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
1139
1140 values = []
1141 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1142 quantized_min = min(ix)
1143 quantized_max = max(ix)
1144 for x in ix:
1145 input_value = x - zp_in
1146 input_value_hires = input_value * 128
1147 # Compute the input value on essentially the output scale, not shifted yet
1148 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
1149 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
1150 relu_value = np.int16(input_value_hires)
1151 if relu_shift < 31:
1152 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
1153
1154 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
1155
1156 if relu_shift < 31:
1157 relu_value = fp_math.shift_left16(relu_value, 1)
1158
1159 if relu_shift > 31:
1160 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
1161
1162 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
1163 # Now convert that to a 16bit fixedpoint value in [0, 1]
1164 relu_value = (relu_value + (1 << 15)) >> 1
1165 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
1166 shift = 31 - out_shift
1167 shift = -shift if shift < 0 else 0
1168 # Finally apply the output shift
1169 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
1170 lut_result = min(quantized_max, max(quantized_min, lut_result))
1171 values.append(lut_result)
1172 return convert_to_lut(op, values, "hardswish")
1173 return op
1174
1175
1176def convert_lrelu_to_mul_max(op, arch):
1177 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
1178 # (the opposite of convert_mul_max_to_abs_or_lrelu)
1179 ifm, ofm = op.get_ifm_ofm()
1180 if ifm is None or ofm is None:
1181 return op
1182
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001183 alpha = np.float32(op.attrs["alpha"])
1184 use_mul_max = 0 < alpha < 1
Fredrik Svedberg36424312022-09-16 09:39:26 +02001185 is_converted_prelu = "alpha_scaling" in op.attrs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001186 if use_mul_max:
1187 mul_ifm = ifm
1188 new_op = Op.Maximum
1189 else:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001190 # Need to use a different approach for alpha < 0 or alpha > 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001191 no_scale_quant = ifm.quantization.clone()
1192 no_scale_quant.scale_f32 = None
1193 no_scale_quant.zero_point = 0
1194 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
1195
1196 # Select values < 0
1197 min_op = Operation(Op.Minimum, op.name + "_min")
1198 min_op.add_input_tensor(ifm)
1199 min_op.add_input_tensor(zero)
1200 mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001201 if alpha < 0 and not is_converted_prelu:
1202 # For negative alpha that is not from a converted PReLU we need to use
1203 # int32 Mul below to perform the (negative) alpha scaling
1204 mul_ifm.dtype = DataType.int32
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001205 min_op.set_output_tensor(mul_ifm)
1206 min_op.set_ifm_ofm_shapes()
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001207 new_op = Op.Add
1208 op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001209 DebugDatabase.add_optimised(op, min_op)
1210
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001211 # Add multiplication with alpha
1212 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001213 mul_alpha.add_input_tensor(mul_ifm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001214 # Create const tensor containing alpha as scalar
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001215 quantization = ifm.quantization.clone()
1216 quantization.min = 0
1217 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
1218 quantization.zero_point = 0
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001219 alpha_dtype = mul_ifm.dtype
Fredrik Svedberg36424312022-09-16 09:39:26 +02001220 if is_converted_prelu:
1221 # The LeakyRelu was the result from convert_prelu and the scaling is provided
Fredrik Svedberg66591652022-08-29 10:51:27 +02001222 scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001223 mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001224 elif alpha == 0 or np.isinf(1 / alpha):
1225 # Handling of alpha near or at zero
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001226 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001227 scalar = 0
1228 else:
1229 quantization.scale_f32 = alpha
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001230 if alpha_dtype == DataType.int32:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001231 # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001232 scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
1233 else:
1234 scalar = 1
Tim Hall3b1578e2023-01-13 17:57:25 +00001235 alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001236 mul_alpha.add_input_tensor(alpha_tens)
1237 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1238 mul_alpha.set_output_tensor(fm_alpha)
1239 mul_alpha.set_ifm_ofm_shapes()
1240 DebugDatabase.add_optimised(op, mul_alpha)
1241
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001242 if not use_mul_max:
1243 relu_op = Operation(Op.Relu, op.name + "_relu")
1244 relu_op.add_input_tensor(ifm)
1245 fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1246 relu_op.set_output_tensor(fm_id)
1247 relu_op.set_ifm_ofm_shapes()
1248 DebugDatabase.add_optimised(op, relu_op)
1249 elif check_quantized_tens_scaling_equal(ifm, ofm):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001250 # No identity multiplication is needed
1251 fm_id = ifm
1252 else:
1253 # Add multiplication with identity
1254 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1255 mul_identity.add_input_tensor(ifm)
1256 # Create const tensor containing identity as scalar
1257 quantization = ifm.quantization.clone()
1258 quantization.min = 0
1259 quantization.max = quantization.quant_max - quantization.quant_min
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001260 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001261 quantization.zero_point = 0
Tim Hall3b1578e2023-01-13 17:57:25 +00001262 identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001263 mul_identity.add_input_tensor(identity_tens)
1264 # Make sure that fm_id is allocated to a different address than fm_alpha
1265 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1266 mul_identity.set_output_tensor(fm_id)
1267 mul_identity.set_ifm_ofm_shapes()
1268 DebugDatabase.add_optimised(op, mul_identity)
1269
1270 # Convert LeakyRelu to Max (or Add, when use_mul_max is False), with the results of the multiplication(s) as inputs
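# e.g. with alpha = 0.1 (illustrative value): max(5, 0.1 * 5) = 5 and max(-5, 0.1 * -5) = -0.5,
# matching LeakyRelu; the Relu/Add path instead effectively computes Relu(x) + alpha * min(x, 0)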
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001271 op.type = new_op
1272 op.name = op.name.replace("LeakyRelu", new_op.name)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001273 op.inputs = []
1274 ifm.consumer_list.remove(op)
1275 op.add_input_tensor(fm_alpha)
1276 op.add_input_tensor(fm_id)
1277 op.set_ifm_ofm_shapes()
1278
1279 DebugDatabase.add_optimised(op, op)
1280 return op
1281
1282
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001283def convert_to_lut8(op, fn, fn_name):
1284 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1285 # fn is a function(real) -> real
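# Each quantized input x is dequantized, passed through fn and requantized:
# lut[x] = clamp(round_away_zero(zp_out + fn(ifm_scale * (x - zp_in)) / ofm_scale))
# e.g. an int8 sigmoid with ifm_scale = 0.1, zp_in = 0, ofm_scale = 1 / 256, zp_out = -128
# (illustrative values) maps x = 10 to round_away_zero(-128 + sigmoid(1.0) * 256) = 59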
1286 ifm, ofm = op.get_ifm_ofm()
1287 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1288 return op
1289 # Generate the LUT
1290 ifm_scale = np.double(ifm.quantization.scale_f32)
1291 ofm_scale = np.double(ofm.quantization.scale_f32)
1292 zp_in = ifm.quantization.zero_point
1293 zp_out = ofm.quantization.zero_point
1294 values = []
1295 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1296 quantized_min = min(ix)
1297 quantized_max = max(ix)
1298 for x in ix:
1299 x_real = ifm_scale * (x - zp_in)
1300 y_real = fn(x_real)
1301 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1302 lut_result = min(quantized_max, max(quantized_min, lut_result))
1303 values.append(lut_result)
1304 return convert_to_lut(op, values, fn_name)
1305
1306
1307def convert_lrelu_to_lut(op, arch):
1308 ifm, ofm = op.get_ifm_ofm()
1309 # Generate the LUT
1310 alpha = op.attrs["alpha"]
1311 ifm_scale = np.double(ifm.quantization.scale_f32)
1312 ofm_scale = np.double(ofm.quantization.scale_f32)
1313 zp_in = ifm.quantization.zero_point
1314 zp_out = ofm.quantization.zero_point
1315 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1316 alpha_scalar = 1
1317 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1318 if "alpha_scaling" in op.attrs:
1319 # The LeakyRelu is the result of convert_mul_max_to_abs_or_lrelu
1320 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1321 values = []
1322 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1323 quantized_min = min(ix)
1324 quantized_max = max(ix)
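# The LUT is piecewise: entries below the input zero point use the alpha scaling, the
# remaining entries use the identity scaling, both folded into fixed point multipliers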
1325 for x in ix:
1326 if x < zp_in:
1327 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1328 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1329 )
1330 else:
1331 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1332 lut_result = min(quantized_max, max(quantized_min, lut_result))
1333 values.append(lut_result)
1334 return convert_to_lut(op, values, "lrelu")
1335
1336
1337def convert_lrelu(op, arch, nng):
1338 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1339 if op.type != Op.LeakyRelu:
1340 return op
1341 ifm, ofm = op.get_ifm_ofm()
1342 if ifm is None or ofm is None:
1343 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001344 alpha = op.attrs["alpha"]
1345 if alpha == 0:
1346 # When alpha is 0 the operation can be converted to a ReLU
1347 op.type = Op.Relu
1348 op.name = op.name.replace("LeakyRelu", op.type.name)
1349 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001350 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1351 # use LUT for int8/uint8
1352 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001353 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001354 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001355 return op
1356 return convert_lrelu_to_mul_max(op, arch)
1357
1358
1359def convert_tanh_sigmoid_to_lut(op, arch, nng):
1360 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1361 if op.type == Op.Sigmoid:
1362 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1363 elif op.type == Op.Tanh:
1364 return convert_to_lut8(op, math.tanh, "tanh")
1365 return op
1366
1367
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001368def remove_memory_only_ops(op, arch):
1369 if op.run_on_npu and op.type in memory_only_ops:
1370 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001371
1372
1373def fuse_activation_function_with_prev(op, arch, nng):
1374 # if op is a no-op: attempts to move the activation function to the preceding op
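# e.g. (illustrative) a ReLU left on a bypassed memory only op is moved onto the NPU op
# (such as a Conv2D) that produces its input, saving a separate pass over the feature map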
1375 if not op.attrs.get("is_nop", False) or op.activation is None:
1376 return op
1377 ifm, ofm = op.get_ifm_ofm()
1378 if ifm is None or ofm is None:
1379 return op
1380 # finds the input(s) to the operation
1381 prev_op = ifm.ops[0]
1382 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1383 fuse = (
1384 prev_op.run_on_npu
1385 and prev_op.type.npu_block_type != NpuBlockType.Default
1386 and len(ifm.ops) == 1
1387 and len(prev_op.outputs[0].consumers()) == 1
1388 and prev_op.activation is None
1389 )
1390 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1391 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1392 # LUT currently only works correctly for elementwise ops
1393 fuse = False
1394 if not fuse:
1395 return op
1396 # Move the fused activation function + corresponding info to prev_op
1397 prev_op.activation = op.activation
1398 prev_op.forced_output_quantization = op.forced_output_quantization
1399 if op.activation_lut is not None:
1400 prev_op.set_activation_lut(op.activation_lut)
1401 # Bypass op
1402 prev_op.set_output_tensor(ofm)
wilisa0179a89042022-11-02 17:18:43 +00001403 DebugDatabase.add_optimised(prev_op, prev_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001404 return op
1405
1406
1407def _leading_pad_ok(leading_pad, stride, kernel_size):
1408 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1409 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
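# e.g. kernel_size = 7, stride = 2 (illustrative values) gives max_size = 3: a leading pad
# of 1 is rejected, while 2 (a multiple of the stride) and 3 (== max_size) are accepted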
1410 max_size = kernel_size // 2
1411 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1412
1413
1414def replace_pad_by_hw_pad(op: Operation, arch, nng):
1415 """
1416 Tries to completely remove a PAD operator by using hardware padding.
1417 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1418 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1419 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1420 if both operations can be run on the NPU.
1421 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1422 """
1423 if (
1424 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001425 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001426 and op.run_on_npu
1427 and op.attrs["padding"] == Padding.VALID
1428 ):
1429 pad_op = op.ifm.ops[0]
1430 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1431 return op
1432 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1433 return op
1434 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1435 k = op.kernel
1436 k_w, k_h = k.dilated_wh()
1437
1438 # Check if the PAD operator can be replaced by hardware padding
1439 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1440 # Too much padding, it would require hardware padding to actually insert zeros
1441 return op
1442 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1443 return op
1444
1445 if op.type.is_avgpool_op():
1446 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1447 for pad, k_size in (
1448 (left, k_w),
1449 (right, k_w),
1450 (top, k_h),
1451 (bottom, k_h),
1452 ):
1453 if pad not in (0, k_size // 2):
1454 return op
1455 # Average pool is converted to depthwise, because NPU average pool + same padding
1456 # has a special implementation that is different from PAD followed by average pool with
1457 # valid padding.
1458 k_w, k_h = op.kernel.width, op.kernel.height
1459 ifm = op.ifm
1460 # Remember other inputs
1461 other_inputs = op.inputs[1:]
1462 # Create a weight tensor with all weights set to 1; the weight scale 1/(kernel width * kernel height) makes each effective weight 1/(k_w * k_h)
1463 quantization = QuantizationParameters(0.0, 255.0)
1464 quantization.scale_f32 = 1.0 / (k_w * k_h)
1465 quantization.zero_point = 0
1466 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1467 weights = np.full(shape, 1)
1468
1469 weight_tens = create_const_tensor(
1470 op.name + "_weights",
1471 shape,
1472 op.ifm.dtype,
1473 weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001474 purpose=TensorPurpose.Weights,
1475 quantization=quantization,
1476 )
James Peet7519d502021-07-19 16:47:58 +01001477 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001478 op.type = Op.DepthwiseConv2DBias
1479 op.inputs = []
1480 op.add_input_tensor(ifm)
1481 op.add_input_tensor(weight_tens)
1482 # Add bias tensor, all biases set to 0
1483 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001484 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001485 # Add other inputs
1486 op.inputs.extend(other_inputs)
1487 op.rounding_mode = NpuRoundingMode.NATURAL
1488
1489 # Bypass the PAD operator
1490 op.set_input_tensor(pad_op.ifm, 0)
1491 # Adjust the padding attributes of the convolution operator
1492 op.attrs["padding"] = Padding.EXPLICIT
1493 op.attrs["explicit_padding"] = (top, left, bottom, right)
1494 op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +00001495 DebugDatabase.add_optimised(op, op)
1496
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001497 return op
1498
1499
1500def convert_pad(op: Operation, arch, nng):
1501 """
1502 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1503 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1504 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1505 """
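# The OFM is assembled from up to five regions (illustrative layout):
# +------------------------+
# |           top          |
# +------+--------+--------+
# | left |  IFM   | right  |
# +------+--------+--------+
# |         bottom         |
# +------------------------+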
1506 if op.type != Op.Pad or not op.run_on_npu:
1507 return op
1508 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1509
1510 ifm = op.ifm
1511 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001512 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001513 ofm = op.ofm
1514 assert ofm is not None
1515 ofm.ops = []
1516 ofm_shape = op.ofm_shapes[0]
1517
1518 # Average pool op that copies IFM to the right place inside the OFM
1519 shp0 = Shape4D(0, 0, 0, 0)
1520 shp_top = shp0.with_height(top)
1521 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1522 avgpool_op.activation = op.activation
1523 quant = ofm.quantization
1524 pad_value = quant.zero_point
1525 # Add operations that fill the borders of the OFM
1526 if top > 0:
1527 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1528 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001529 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001530 )
1531 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1532 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1533 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1534 if bottom > 0:
1535 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1536 zero_tens = create_const_tensor(
1537 op.name + "_bottom",
1538 shape.as_list(),
1539 ofm.dtype,
1540 shape.elements() * [pad_value],
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001541 quantization=quant,
1542 )
1543 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1544 create_avg_pool_for_concat(
1545 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1546 )
1547 if left > 0:
1548 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1549 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001550 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001551 )
1552 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1553 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1554 if right > 0:
1555 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1556 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001557 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001558 )
1559 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1560 create_avg_pool_for_concat(
1561 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1562 )
1563
1564 op.type = Op.ConcatTFLite
1565 return avgpool_op
1566
1567
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001568def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001569 if op.type.needs_bias() and op.bias is None:
1570 # Op has no bias, add bias tensor filled with zeros
1571 nr_biases = op.inputs[1].shape[-1]
1572 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001573 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1574 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1575 # For int16 the selected bias DataType will have an impact on the scaling
1576 # used when encoding the scales and biases later. The default mode will match the
1577 # reference with reduced scaling for int64 bias.
1578 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1579 # is used to emulate average pool int32 bias should be selected for full precision
1580 # int16 scaling.
1581 if dtype is None:
1582 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1583 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001584 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1585
1586 return op
1587
1588
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001589def fixup_asymmetric_weights(op, arch, nng):
1590 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1591 if op.ifm.dtype == DataType.int8:
1592 if not np.all(op.weights.quantization.zero_point == 0):
1593 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1594 op.weights.quantization.zero_point *= 0
1595
1596 return op
1597
1598
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001599def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1600 if op.type == Op.Mean and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001601 inp, axis = op.inputs
1602 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001603 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001604 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001605 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001606
1607 # Height and width axes have different indices depending on the number of dimensions
1608 if axis.shape == [] or axis.shape[0] == 1: # single axis
1609 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1610 if dims in (2, 3):
1611 if axis == 0:
1612 h, w = shape[axis], 1
1613 else:
1614 h, w = 1, shape[axis]
1615 else:
1616 if axis == 1:
1617 h, w = shape[axis], 1
1618 else:
1619 h, w = 1, shape[axis]
1620 else: # multiple axes
1621 axis = sorted(axis.values)
1622 h, w = [shape[i] for i in axis]
1623
1624 # Set necessary depthwise attributes
1625 op.attrs.update(
1626 {
1627 "padding": Padding.VALID,
1628 "stride_h": 1,
1629 "stride_w": 1,
1630 "strides": (1, 1, 1, 1),
1631 "depth_multiplier": 1,
1632 "channel_multiplier": 1,
1633 "dilation_h_factor": 1,
1634 "dilation_w_factor": 1,
1635 "dilation": (1, 1, 1, 1),
1636 }
1637 )
1638 # Change op type
1639 op.type = Op.DepthwiseConv2DBias
1640 # Set IFM/OFM shapes after changing op type
1641 op.set_ifm_ofm_shapes()
1642
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001643 weight_scale, bias = 1, 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001644 ofmq, ifmq = op.ofm.quantization, inp.quantization
Johan Alfvén9d51ec42022-10-27 16:30:01 +02001645 if ifmq.is_scaling_equal(ofmq):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001646 # Here we can just use a simple AvgPool with truncating rounding,
1647 # as we're emulating simple integer division.
1648 op.rounding_mode = NpuRoundingMode.TRUNCATE
1649 op.type = Op.AvgPool
1650 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1651 else:
1652 op.rounding_mode = NpuRoundingMode.NATURAL
1653 weight_scale = 1 / (h * w)
1654 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1655 bias = -ifmq.zero_point * h * w
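# With all-ones weights the depthwise convolution computes sum(x) + bias = sum(x) - zp * h * w
# = sum(x - zp), and the 1/(h * w) weight scale then turns this into the zero point corrected mean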
1656 fiq = ifmq.clone()
1657 fiq.zero_point = 0
1658 op.forced_input_quantization = fiq
1659
1660 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001661 def extend_dims(dim, in_shape):
1662 if dim < 4:
1663 in_shape = [1] + in_shape
1664 if dim == 2:
1665 in_shape += [1]
1666 return in_shape
1667
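# e.g. (illustrative) extend_dims(2, [H, W]) -> [1, H, W, 1] and extend_dims(3, [H, W, C]) -> [1, H, W, C]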
1668 if dims < 4 or dims_ofm < 4:
1669 # Fix the ofm dimension when keep_dims is false
1670 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1671 if isinstance(axis, int) and dims_ofm + 1 == dims:
1672 ofm_shape.insert(axis, 1)
1673 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1674 for i in axis:
1675 ofm_shape.insert(i, 1)
1676 shape = extend_dims(dims, shape)
1677 dims_ofm = len(ofm_shape)
1678 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1679 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001680
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001681 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001682 weight_shape = None
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001683 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001684 # This can only happen when reducing over multiple axes, and requires
1685 # h * w <= 256 for DepthwiseConv2DBias
1686 # h * w <= 4096 for AvgPool
1687 # both of which are checked in the supported operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001688 shape = [shape[0], 1, h * w, shape[3]]
1689 op.ifm_shapes[0] = Shape4D(shape)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001690 weight_shape = [1, h * w, shape[3], shape[0]]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001691 if h > 256 and op.type == Op.AvgPool:
1692 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1693
1694 # If the AvgPool version is used, we don't need to do anything else
1695 if op.type == Op.AvgPool:
wilisa0179a89042022-11-02 17:18:43 +00001696 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001697 return op
1698
1699 # Make unit weight tensor quantization
1700 weight_quant = ifmq.clone()
1701 weight_quant.min = 0
1702 weight_quant.max = 255
1703 weight_quant.scale_f32 = weight_scale
1704 weight_quant.zero_point = 0
1705
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001706 if weight_shape is None:
1707 # Set weight shape to [H,W,C,B]
1708 weight_shape = [h, w, shape[3], shape[0]]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001709
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001710 # Add unit weight tensor
1711 op.set_input_tensor(
1712 create_const_tensor(
1713 "weights",
1714 weight_shape,
1715 inp.dtype,
1716 np.ones(weight_shape),
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001717 quantization=weight_quant,
1718 ),
1719 1,
1720 )
James Peet7519d502021-07-19 16:47:58 +01001721 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001722
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001723 # Add bias tensor
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001724 bias_shape = [shape[-1]]
1725 op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
wilisa0179a89042022-11-02 17:18:43 +00001726 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001727
1728 return op
1729
1730
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001731def optimise_quantize(op: Operation, arch, nng):
1732
1733 if op.type == Op.Quantize and op.run_on_npu:
1734
1735 ifm, ofm = op.get_ifm_ofm()
1736 input_values = ifm.values
1737
1738 # Guard clause - input not const or no values to quantize
1739 if ifm.ops[0].type != Op.Const or input_values is None:
1740 return op
1741
1742 # Scalar value in numpy array, convert to an indexable array
1743 if input_values.ndim == 0:
1744 input_values = np.array([input_values])
1745
Fredrik Svedberg11563172022-07-06 14:54:12 +02001746 # Case: requantize int8 to int8 or int16 to int16
1747 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001748
1749 # scale needs to use double precision to match TFLite reference kernel
1750 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1751 effective_multiplier, effective_shift = quantise_scale(effective_scale)
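# quantise_scale expresses the effective scale as a fixed point multiplier and shift, so the
# loop below can requantise using integer arithmetic only, matching the TFLite reference kernel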
1752
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001753 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001754 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001755 input_val = val - ifm.quantization.zero_point
1756
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001757 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1758 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001759
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001760 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1761 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001762
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001763 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1764 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001765
1766 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001767 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001768
1769 quantized_vals = []
1770 for val in input_values:
1771
1772 # Derive quantized value
1773 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001774 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1775 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001776
1777 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001778 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1779
1780 # Unsupported data type
1781 else:
1782 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001783
1784 # Make quantize op const and disconnect from parent node
1785
1786 # Remove reference of the current quant op from the parent tensor's consumer list
1787 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1788
1789 # Clear any references to parent node
1790 op.inputs = []
1791
1792 # Convert this quantize op to const
1793 op.type = Op.Const
1794
1795 return op
1796
1797
Ayaan Masood4965fae2022-06-29 11:30:57 +01001798def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1799 """Static optimisation for the SHAPE operator, whose output value is known at compile time"""
1800
1801 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
1802
1803 if op.type == Op.Shape and op.run_on_npu:
1804
1805 ifm, ofm = op.get_ifm_ofm()
1806
1807 if len(ifm.shape) != ofm.shape[0]:
1808 return op
1809
1810 # Remove reference of the current shape op from the parent tensor's consumer list
1811 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1812
1813 # Clear any references to parent node
1814 op.inputs = []
1815
1816 # Convert this SHAPE op to const
1817 op.type = Op.Const
1818
1819 # Set the output values to the compile time known shape of the IFM
1820 ofm.values = np.array(ifm.shape)
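# e.g. (illustrative) an IFM of shape [1, 224, 224, 3] turns this op into a Const producing
# the values [1, 224, 224, 3]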
1821
1822 return op
1823
1824
Tim Hallea4ba662022-11-11 18:19:53 +00001825def fixup_dilation_gt2(op, arch, nng):
1826 assert op.run_on_npu
1827 if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
1828 dilation_w, dilation_h = op.get_kernel_dilation()
1829
1830 # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
1831 # kernel
1832 if dilation_w > 2 or dilation_h > 2:
1833 kernel_w, kernel_h = op.get_kernel_size()
1834 kernel_ic = op.weights.shape[-2]
1835 kernel_oc = op.weights.shape[-1]
1836
1837 # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
1838 # of 2. This allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
1839 # odd = 1, even = 2
1840 hw_dilation_h = 1 if (dilation_h & 1) else 2
1841 hw_dilation_w = 1 if (dilation_w & 1) else 2
1842
1843 scale_dilation_h = dilation_h // hw_dilation_h
1844 scale_dilation_w = dilation_w // hw_dilation_w
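# e.g. dilation 6 (illustrative) gives hw dilation 2 and scale dilation 3: a 3 tap kernel
# becomes (3 - 1) * 3 + 1 = 7 taps, with the remaining factor of 2 done by the hardware dilation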
1845
1846 # create new empty kernel (HWIO format)
1847 new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
1848 new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1
1849
1850 new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
1851 new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)
1852
1853 # copy the original kernel values into the new sparse kernel
1854 for h in range(0, kernel_h):
1855 for w in range(0, kernel_w):
1856 new_h = h * scale_dilation_h
1857 new_w = w * scale_dilation_w
1858 new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]
1859
1860 # update the weight tensor with the new dilated kernel
1861 op.weights.shape = new_kernel_shape
1862 op.weights.values = new_kernel_values
1863
1864 # enable(=2) / disable(=1) hardware dilation
1865 op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1) # nhwc format
1866 op.attrs["dilation_h_factor"] = hw_dilation_h
1867 op.attrs["dilation_w_factor"] = hw_dilation_w
1868
1869 return op
1870
1871
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001872def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001873 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001874 return op
1875
1876
1877def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001878 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001879 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1880
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001881 for idx, sg in enumerate(nng.subgraphs):
1882 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001883 nng,
1884 sg,
1885 arch,
1886 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001887 optimisation_list,
1888 rewrite_unsupported=False,
1889 )
1890
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001891 # Pre-processing step
1892 pre_process_list = [
1893 supported_operator_check,
1894 set_ifm_ofm_op_shapes,
1895 ]
1896
Ayaan Masood4965fae2022-06-29 11:30:57 +01001897 for idx, sg in enumerate(nng.subgraphs):
1898 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1899 nng,
1900 sg,
1901 arch,
1902 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001903 pre_process_list,
1904 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001905 )
1906
1907 # Handle Concat Ops
1908 for idx, sg in enumerate(nng.subgraphs):
1909 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1910 sg.refresh_after_modification()
1911
1912 # Handle Split Ops
1913 for idx, sg in enumerate(nng.subgraphs):
1914 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1915 nng,
1916 sg,
1917 arch,
1918 [],
1919 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1920 rewrite_unsupported=False,
1921 )
1922
1923 for idx, sg in enumerate(nng.subgraphs):
1924 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001925 nng,
1926 sg,
1927 arch,
1928 [rewrite_split_ops],
1929 [],
1930 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001931 )
1932
1933 # Handle sg input output
1934 for idx, sg in enumerate(nng.subgraphs):
1935 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001936 nng,
1937 sg,
1938 arch,
1939 [],
1940 [fix_sg_input_output],
1941 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001942 )
1943
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001944 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001945 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001946 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001947 sg.refresh_after_modification()
1948
1949 # Rewrite of operators
1950 op_rewrite_list = [
1951 set_tensor_equivalence,
1952 convert_mean_to_depthwise_conv_or_avgpool,
1953 convert_depthwise_to_conv,
1954 convert_conv_to_fc,
1955 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001956 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001957 convert_mul_max_to_abs_or_lrelu,
1958 convert_lrelu,
Raul Farkas090f18a2023-01-24 16:29:06 +00001959 fixup_strided_conv,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001960 convert_hardswish_to_lut,
1961 rewrite_fully_connected_input,
1962 convert_batched_fc_shape,
1963 fixup_conv2d_backprop,
1964 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001965 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001966 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001967 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001968 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001969 convert_tanh_sigmoid_to_lut,
1970 replace_pad_by_hw_pad,
Tim Hallea4ba662022-11-11 18:19:53 +00001971 fixup_dilation_gt2,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001972 ]
1973
1974 for idx, sg in enumerate(nng.subgraphs):
1975 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001976 nng,
1977 sg,
1978 arch,
1979 [],
1980 op_rewrite_list,
1981 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001982 )
1983
1984 for idx, sg in enumerate(nng.subgraphs):
1985 # remove passthrough tensors and attempt further optimizations
1986 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1987 nng,
1988 sg,
1989 arch,
1990 [remove_passthrough_tensor],
1991 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1992 )
1993
1994 # Removal of SplitSliceRead; this needs to be done after the optimisations above,
1995 # since this function relies on the final ifm/ofm_shapes
1996 for sg in nng.subgraphs:
1997 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1998 sg.refresh_after_modification()
1999
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01002000 # Make sure that const optimisations on subgraph outputs are handled correctly
2001 for sg in nng.subgraphs:
2002 for ofm in sg.output_tensors:
2003 if ofm.is_const and ofm.ops[0].type_changed:
2004 # Subgraph output cannot be const - insert a memory copy
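# The copy is built as an elementwise Add with a zero scalar: arithmetically a no-op, but it
# forces the values to be written to the non-const output tensor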
2005 op = ofm.ops[0]
2006 ofm_clone = ofm.clone()
2007 ofm_clone.values = ofm.values
2008 ofm.values = None
Tim Hall3b1578e2023-01-13 17:57:25 +00002009 zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01002010 memcpy = create_add_nop(f"{ofm.name}_copy")
2011 memcpy.add_input_tensor(ofm_clone)
2012 memcpy.add_input_tensor(zero)
2013 memcpy.set_output_tensor(ofm)
2014 memcpy.set_ifm_ofm_shapes()
2015 op.set_output_tensor(ofm_clone)
2016 DebugDatabase.add_optimised(op, memcpy)
2017
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002018 return nng