# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


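# Rewrite a concatenation (Concat/Pack) so that each input is written straight into the correct
# slice of the output feature map by a no-op average pool carrying a write offset. For example,
# two [1, 8, 8, 16] inputs concatenated along the depth axis become two avgpools that write at
# depth offsets 0 and 16 of the shared OFM.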
def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


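# Rewrite split-style ops (Split/SplitV/Slice/StridedSlice) into a SplitSliceRead so that the
# consumer reads the relevant slice of the input directly, using a read offset and read shape.
# For a plain Split the offset is not stored on any tensor, so it is accumulated from the sizes
# of the preceding outputs along the split axis.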
def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


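# Example (SAME padding, with the usual SAME-padding rule): a 3x3 dilated kernel with stride 1 on
# a 224-wide IFM needs xpad = 2, split as left_pad = 1 and right_pad = 1; the skirt records the
# same totals expressed as (top_pad, left_pad, ypad - top_pad, xpad - left_pad).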
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change resizebilinear to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype == DataType.uint8:
        weight_value_dtype = np.uint8
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        if ofm_dtype == DataType.int8:
            weight_value_dtype = np.int8
        else:
            assert ofm_dtype == DataType.int16
            weight_value_dtype = np.int16

        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm.dtype,
            np.array(weight_values).reshape(weight_shape),
            value_dtype=weight_value_dtype,
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()

    return op


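# Example: an x8 resize gives n = log2(8) = 3, so the loop below emits two nearest neighbor x2
# upscale passes (1x1 average pools) and the cloned final op performs the last x2 step; for
# bilinear without align_corners the final op then uses an 8x8 average pool kernel with explicit
# padding [0, 0, 7, 7].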
# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # Keep 1x1 kernel and average pool, this applies both when
            # half-pixel-centers is True and False. Calculations are the
            # same in the reference.
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()

    return op


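# Convert a half-pixel-centers ResizeBilinear x2 upscale into four 2x2 depthwise convolutions, one
# per output phase, written interleaved into the OFM via tile base offsets and doubled H/W output
# strides. For half_pixel_centers=True the interpolation weights below work out to the familiar
# bilinear kernel [[1, 3], [3, 9]] / 16 (with the coefficients mirrored for the other three phases).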
def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause an x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        value_dtype=np.int8,
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


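# A FullyConnected with IFM batch > 1 is mapped onto the NPU as a 2D grid of batches, e.g. batch 8
# becomes an IFM shape of 1 x 2 x 4 x depth (see batching_split below); other batch sizes fall back
# to 1 x 1 x n x depth.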
def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


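# The loops below peel off one set bit of the mask at a time: clearing the lowest set bit with
# mask &= mask - 1 and taking log2 of the difference gives the axis index. E.g. shrink_axis_mask
# = 0b100 re-inserts a 1 at axis 2 of the output shape so the 4D shapes stay consistent.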
def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


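# Lower PReLU in order of preference: a constant alpha where all values are equal becomes a plain
# Relu (alpha == 0) or a LeakyRelu; a constant alpha with alpha_max < 1 becomes Max(alpha * IFM, IFM);
# anything else is decomposed into Minimum/Mul for the negative part, Relu for the positive part
# and a final non-scaling Add.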
def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(np.int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(np.int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.Add, op.name + "_add")
        add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |   \   /           This subgraph can be replaced with either
    |    Mul            an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


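# HardSwish (x * relu6(x + 3) / 6 in the reference) depends only on the quantized input value, so
# it is lowered to a 256-entry LUT: the loop below evaluates the fixed-point arithmetic used by the
# TensorFlow Lite Micro kernel for every possible int8/uint8 input.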
def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


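# Two lowering paths are used below: for 0 < alpha < 1 the result is Max(alpha * IFM, IFM)
# (Mul + Maximum), while any other alpha is handled as Relu(IFM) + alpha * Min(IFM, 0) using a
# non-scaling Add, with an int32 Mul when a negative alpha has to carry the scaling itself.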
def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    alpha = np.float32(op.attrs["alpha"])
    use_mul_max = 0 < alpha < 1
    is_converted_prelu = "alpha_scaling" in op.attrs
    if use_mul_max:
        mul_ifm = ifm
        new_op = Op.Maximum
    else:
        # Need to use a different approach for alpha < 0 or alpha > 1
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
        if alpha < 0 and not is_converted_prelu:
            # For negative alpha that is not from a converted PReLU we need to use
            # int32 Mul below to perform the (negative) alpha scaling
            mul_ifm.dtype = DataType.int32
        min_op.set_output_tensor(mul_ifm)
        min_op.set_ifm_ofm_shapes()
        new_op = Op.Add
        op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        DebugDatabase.add_optimised(op, min_op)

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(mul_ifm)
    # Create const tensor containing alpha as scalar
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    alpha_dtype = mul_ifm.dtype
    if is_converted_prelu:
        # The LeakyRelu was the result from convert_prelu and the scaling is provided
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
    elif alpha == 0 or np.isinf(1 / alpha):
        # Handling of alpha near or at zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        if alpha_dtype == DataType.int32:
            # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
            scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
        else:
            scalar = 1
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
    )
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if not use_mul_max:
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_id)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)
    elif check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = new_op
    op.name = op.name.replace("LeakyRelu", new_op.name)
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


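# Generic helper: dequantize each of the 256 possible input values, apply fn and requantize the
# result, e.g. convert_to_lut8(op, clamp_sigmoid, "sigmoid") as done in convert_tanh_sigmoid_to_lut
# further down.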
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001272def convert_to_lut8(op, fn, fn_name):
1273 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1274 # fn is a function(real) -> real
1275 ifm, ofm = op.get_ifm_ofm()
1276 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1277 return op
1278 # Generate the LUT
1279 ifm_scale = np.double(ifm.quantization.scale_f32)
1280 ofm_scale = np.double(ofm.quantization.scale_f32)
1281 zp_in = ifm.quantization.zero_point
1282 zp_out = ofm.quantization.zero_point
1283 values = []
1284 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1285 quantized_min = min(ix)
1286 quantized_max = max(ix)
1287 for x in ix:
1288 x_real = ifm_scale * (x - zp_in)
1289 y_real = fn(x_real)
1290 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1291 lut_result = min(quantized_max, max(quantized_min, lut_result))
1292 values.append(lut_result)
1293 return convert_to_lut(op, values, fn_name)
1294
1295
1296def convert_lrelu_to_lut(op, arch):
1297 ifm, ofm = op.get_ifm_ofm()
1298 # Generate the LUT
1299 alpha = op.attrs["alpha"]
1300 ifm_scale = np.double(ifm.quantization.scale_f32)
1301 ofm_scale = np.double(ofm.quantization.scale_f32)
1302 zp_in = ifm.quantization.zero_point
1303 zp_out = ofm.quantization.zero_point
1304 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1305 alpha_scalar = 1
1306 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1307 if "alpha_scaling" in op.attrs:
1308 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1309 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1310 values = []
1311 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1312 quantized_min = min(ix)
1313 quantized_max = max(ix)
1314 for x in ix:
1315 if x < zp_in:
1316 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1317 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1318 )
1319 else:
1320 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1321 lut_result = min(quantized_max, max(quantized_min, lut_result))
1322 values.append(lut_result)
1323 return convert_to_lut(op, values, "lrelu")
1324
1325
1326def convert_lrelu(op, arch, nng):
1327 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1328 if op.type != Op.LeakyRelu:
1329 return op
1330 ifm, ofm = op.get_ifm_ofm()
1331 if ifm is None or ofm is None:
1332 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001333 alpha = op.attrs["alpha"]
1334 if alpha == 0:
1335 # When alpha is 0 the operation can be converted to a ReLU
1336 op.type = Op.Relu
1337 op.name = op.name.replace("LeakyRelu", op.type.name)
1338 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001339 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1340 # use LUT for int8/uint8
1341 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001342 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001343 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001344 return op
1345 return convert_lrelu_to_mul_max(op, arch)
1346
1347
1348def convert_tanh_sigmoid_to_lut(op, arch, nng):
1349 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1350 if op.type == Op.Sigmoid:
1351 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1352 elif op.type == Op.Tanh:
1353 return convert_to_lut8(op, math.tanh, "tanh")
1354 return op
1355
1356
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001357def remove_memory_only_ops(op, arch):
1358 if op.run_on_npu and op.type in memory_only_ops:
1359 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001360
1361
1362def fuse_activation_function_with_prev(op, arch, nng):
1363 # if op is a no-op: attempts to move the activation function to the preceding op
1364 if not op.attrs.get("is_nop", False) or op.activation is None:
1365 return op
1366 ifm, ofm = op.get_ifm_ofm()
1367 if ifm is None or ofm is None:
1368 return op
1369 # finds the input(s) to the operation
1370 prev_op = ifm.ops[0]
1371 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1372 fuse = (
1373 prev_op.run_on_npu
1374 and prev_op.type.npu_block_type != NpuBlockType.Default
1375 and len(ifm.ops) == 1
1376 and len(prev_op.outputs[0].consumers()) == 1
1377 and prev_op.activation is None
1378 )
1379 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1380 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1381 # LUT currently only works correctly for elementwise ops
1382 fuse = False
1383 if not fuse:
1384 return op
1385 # Move the fused activation function + corresponding info to prev_op
1386 prev_op.activation = op.activation
1387 prev_op.forced_output_quantization = op.forced_output_quantization
1388 if op.activation_lut is not None:
1389 prev_op.set_activation_lut(op.activation_lut)
1390 # Bypass op
1391 prev_op.set_output_tensor(ofm)
1392 DebugDatabase.add_optimised(op, prev_op)
1393 return op
1394
1395
1396def _leading_pad_ok(leading_pad, stride, kernel_size):
1397 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1398 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
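# Illustrative case: kernel_size=7 (max_size=3) with stride=2 accepts leading pads of 0, 2 or 3
# but rejects a leading pad of 1.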
1399 max_size = kernel_size // 2
1400 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1401
1402
1403def replace_pad_by_hw_pad(op: Operation, arch, nng):
1404 """
1405 Tries to completely remove a PAD operator by using hardware padding.
1406 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1407 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1408 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1409 if both operations can be run on the NPU.
1410 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1411 """
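# In the example above (pad of 1 on each side, 3x3 kernel) the rewritten convolution ends up with
# padding=Padding.EXPLICIT and explicit_padding=(1, 1, 1, 1), which is equivalent to SAME padding.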
1412 if (
1413 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001414 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001415 and op.run_on_npu
1416 and op.attrs["padding"] == Padding.VALID
1417 ):
1418 pad_op = op.ifm.ops[0]
1419 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1420 return op
1421 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1422 return op
1423 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1424 k = op.kernel
1425 k_w, k_h = k.dilated_wh()
1426
1427 # Check if the PAD operator can be replaced by hardware padding
1428 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1429 # Too much padding, it would require hardware padding to actually insert zeros
1430 return op
1431 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1432 return op
1433
1434 if op.type.is_avgpool_op():
1435 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1436 for pad, k_size in (
1437 (left, k_w),
1438 (right, k_w),
1439 (top, k_h),
1440 (bottom, k_h),
1441 ):
1442 if pad not in (0, k_size // 2):
1443 return op
1444 # Average pool is converted to depthwise, because NPU average pool + same padding
1445 # has a special implementation that is different from PAD followed by average pool with
1446 # valid padding.
1447 k_w, k_h = op.kernel.width, op.kernel.height
1448 ifm = op.ifm
1449 # Remember other inputs
1450 other_inputs = op.inputs[1:]
1451 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1452 quantization = QuantizationParameters(0.0, 255.0)
1453 quantization.scale_f32 = 1.0 / (k_w * k_h)
1454 quantization.zero_point = 0
1455 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1456 weights = np.full(shape, 1)
1457
1458 weight_tens = create_const_tensor(
1459 op.name + "_weights",
1460 shape,
1461 op.ifm.dtype,
1462 weights,
1463 np.uint8,
1464 purpose=TensorPurpose.Weights,
1465 quantization=quantization,
1466 )
James Peet7519d502021-07-19 16:47:58 +01001467 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001468 op.type = Op.DepthwiseConv2DBias
1469 op.inputs = []
1470 op.add_input_tensor(ifm)
1471 op.add_input_tensor(weight_tens)
1472 # Add bias tensor, all biases set to 0
1473 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001474 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001475 # Add other inputs
1476 op.inputs.extend(other_inputs)
1477 op.rounding_mode = NpuRoundingMode.NATURAL
1478
1479 # Bypass the PAD operator
1480 op.set_input_tensor(pad_op.ifm, 0)
1481 # Adjust the padding attributes of the convolution operator
1482 op.attrs["padding"] = Padding.EXPLICIT
1483 op.attrs["explicit_padding"] = (top, left, bottom, right)
1484 op.set_ifm_ofm_shapes()
1485 return op
1486
1487
1488def convert_pad(op: Operation, arch, nng):
1489 """
1490 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1491 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1492 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1493 """
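# Illustrative case (hypothetical padding values): (top, left, bottom, right) = (1, 2, 0, 0) produces
# one average pool that copies the IFM into the OFM at offset (height=1, width=2), one op filling the
# single top row and one op filling the two left columns with the output zero point.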
1494 if op.type != Op.Pad or not op.run_on_npu:
1495 return op
1496 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1497
1498 ifm = op.ifm
1499 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001500 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001501 ofm = op.ofm
1502 assert ofm is not None
1503 ofm.ops = []
1504 ofm_shape = op.ofm_shapes[0]
1505
1506 # Average pool op that copies IFM to the right place inside the OFM
1507 shp0 = Shape4D(0, 0, 0, 0)
1508 shp_top = shp0.with_height(top)
1509 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1510 avgpool_op.activation = op.activation
1511 quant = ofm.quantization
1512 pad_value = quant.zero_point
1513 # Add operations that fill the borders of the OFM
1514 if top > 0:
1515 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1516 zero_tens = create_const_tensor(
1517 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1518 )
1519 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1520 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1521 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1522 if bottom > 0:
1523 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1524 zero_tens = create_const_tensor(
1525 op.name + "_bottom",
1526 shape.as_list(),
1527 ofm.dtype,
1528 shape.elements() * [pad_value],
1529 np.uint8,
1530 quantization=quant,
1531 )
1532 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1533 create_avg_pool_for_concat(
1534 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1535 )
1536 if left > 0:
1537 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1538 zero_tens = create_const_tensor(
1539 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1540 )
1541 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1542 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1543 if right > 0:
1544 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1545 zero_tens = create_const_tensor(
1546 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1547 )
1548 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1549 create_avg_pool_for_concat(
1550 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1551 )
1552
1553 op.type = Op.ConcatTFLite
1554 return avgpool_op
1555
1556
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001557def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001558 if op.type.needs_bias() and op.bias is None:
1559 # Op has no bias, add bias tensor filled with zeros
1560 nr_biases = op.inputs[1].shape[-1]
1561 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001562 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1563 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1564 # For int16 the selected bias DataType will have an impact on the scaling
1565 # used when encoding the scales and biases later. The default mode will match the
1566 # reference with reduced scaling for int64 bias.
1567 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1568 # is used to emulate average pool int32 bias should be selected for full precision
1569 # int16 scaling.
1570 if dtype is None:
1571 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1572 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001573 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1574
1575 return op
1576
1577
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001578def fixup_asymmetric_weights(op, arch, nng):
1579 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1580 if op.ifm.dtype == DataType.int8:
1581 if not np.all(op.weights.quantization.zero_point == 0):
1582 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1583 op.weights.quantization.zero_point *= 0
1584
1585 return op
1586
1587
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001588def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1589 if op.type == Op.Mean and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001590 inp, axis = op.inputs
1591 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001592 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001593 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001594 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001595
1596 # Height and width axes have different index depending on dimensions
1597 # Height and width axes have different indices depending on the number of dimensions
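# Illustrative case: a 4D input of shape [1, 16, 24, 8] with axis=1 reduces over height (h=16, w=1),
# while axis=2 reduces over width (h=1, w=24).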
1598 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1599 if dims in (2, 3):
1600 if axis == 0:
1601 h, w = shape[axis], 1
1602 else:
1603 h, w = 1, shape[axis]
1604 else:
1605 if axis == 1:
1606 h, w = shape[axis], 1
1607 else:
1608 h, w = 1, shape[axis]
1609 else: # multiple axes
1610 axis = sorted(axis.values)
1611 h, w = [shape[i] for i in axis]
1612
1613 # Set necessary depthwise attributes
1614 op.attrs.update(
1615 {
1616 "padding": Padding.VALID,
1617 "stride_h": 1,
1618 "stride_w": 1,
1619 "strides": (1, 1, 1, 1),
1620 "depth_multiplier": 1,
1621 "channel_multiplier": 1,
1622 "dilation_h_factor": 1,
1623 "dilation_w_factor": 1,
1624 "dilation": (1, 1, 1, 1),
1625 }
1626 )
1627 # Change op type
1628 op.type = Op.DepthwiseConv2DBias
1629 # Set IFM/OFM shapes after changing op type
1630 op.set_ifm_ofm_shapes()
1631
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001632 weight_scale, bias = 1, 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001633 ofmq, ifmq = op.ofm.quantization, inp.quantization
Johan Alfvén9d51ec42022-10-27 16:30:01 +02001634 if ifmq.is_scaling_equal(ofmq):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001635 # Here we can just use a simple AvgPool with truncating rounding,
1636 # as we're emulating simple integer division.
1637 op.rounding_mode = NpuRoundingMode.TRUNCATE
1638 op.type = Op.AvgPool
1639 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1640 else:
1641 op.rounding_mode = NpuRoundingMode.NATURAL
1642 weight_scale = 1 / (h * w)
1643 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1644 bias = -ifmq.zero_point * h * w
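# Illustrative numbers: h=2, w=4 and zero_point=128 give weight_scale = 1/8 and bias = -1024,
# folding the zero point correction of the mean into the bias.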
1645 fiq = ifmq.clone()
1646 fiq.zero_point = 0
1647 op.forced_input_quantization = fiq
1648
1649 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001650 def extend_dims(dim, in_shape):
1651 if dim < 4:
1652 in_shape = [1] + in_shape
1653 if dim == 2:
1654 in_shape += [1]
1655 return in_shape
1656
1657 if dims < 4 or dims_ofm < 4:
1658 # Fix the ofm dimension when keep_dims is false
1659 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1660 if isinstance(axis, int) and dims_ofm + 1 == dims:
1661 ofm_shape.insert(axis, 1)
1662 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1663 for i in axis:
1664 ofm_shape.insert(i, 1)
1665 shape = extend_dims(dims, shape)
1666 dims_ofm = len(ofm_shape)
1667 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1668 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001669
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001670 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001671 weight_shape = None
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001672 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001673 # This can only happen and be done for multiple axes, and
1674 # h * w <= 256 for DepthwiseConv2DBias
1675 # h * w <= 4096 for AvgPool
1676 # which is checked in supported ops
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001677 shape = [shape[0], 1, h * w, shape[3]]
1678 op.ifm_shapes[0] = Shape4D(shape)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001679 weight_shape = [1, h * w, shape[3], shape[0]]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001680 if h > 256 and op.type == Op.AvgPool:
1681 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1682
1683 # If the AvgPool version is used, we don't need to do anything else
1684 if op.type == Op.AvgPool:
1685 return op
1686
1687 # Make unit weight tensor quantization
1688 weight_quant = ifmq.clone()
1689 weight_quant.min = 0
1690 weight_quant.max = 255
1691 weight_quant.scale_f32 = weight_scale
1692 weight_quant.zero_point = 0
1693
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001694 if weight_shape is None:
1695 # Set weight shape to [H,W,C,B]
1696 weight_shape = [h, w, shape[3], shape[0]]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001697
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001698 # Add unit weight tensor
1699 op.set_input_tensor(
1700 create_const_tensor(
1701 "weights",
1702 weight_shape,
1703 inp.dtype,
1704 np.ones(weight_shape),
1705 value_dtype=np.uint8,
1706 quantization=weight_quant,
1707 ),
1708 1,
1709 )
James Peet7519d502021-07-19 16:47:58 +01001710 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001711
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001712 # Add bias tensor
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001713 bias_shape = [shape[-1]]
1714 op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001715
1716 return op
1717
1718
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001719def optimise_quantize(op: Operation, arch, nng):
1720
1721 if op.type == Op.Quantize and op.run_on_npu:
1722
1723 ifm, ofm = op.get_ifm_ofm()
1724 input_values = ifm.values
1725
1726 # Guard clause - input not const or no values to quantize
1727 if ifm.ops[0].type != Op.Const or input_values is None:
1728 return op
1729
1730 # Scalar value in a numpy array - convert to an indexable array
1731 if input_values.ndim == 0:
1732 input_values = np.array([input_values])
1733
Fredrik Svedberg11563172022-07-06 14:54:12 +02001734 # Requantize int8 to int8 or int16 to int16
1735 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001736
1737 # scale needs to use double precision to match TFLite reference kernel
1738 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1739 effective_multiplier, effective_shift = quantise_scale(effective_scale)
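# Illustrative scales: ifm scale 0.5 and ofm scale 0.25 give an effective scale of 2.0,
# which quantise_scale() encodes as a fixed point multiplier and shift.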
1740
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001741 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001742 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001743 input_val = val - ifm.quantization.zero_point
1744
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001745 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1746 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001747
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001748 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1749 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001750
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001751 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1752 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001753
1754 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001755 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001756
1757 quantized_vals = []
1758 for val in input_values:
1759
1760 # Derive quantized value
1761 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001762 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1763 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001764
1765 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001766 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1767
1768 # Unsupported data type
1769 else:
1770 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001771
1772 # Make quantize op const and disconnect from parent node
1773
1774 # Remove reference of the current quant op from the parent tensor's consumer list
1775 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1776
1777 # Clear any references to parent node
1778 op.inputs = []
1779
1780 # Convert this quantize op to const
1781 op.type = Op.Const
1782
1783 return op
1784
1785
Ayaan Masood4965fae2022-06-29 11:30:57 +01001786def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1787 """Static optimisation for SHAPE operators whose output value is known at compile time"""
1788
1789 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
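# Illustrative case: an IFM of shape [1, 56, 56, 32] turns the SHAPE output into a constant
# tensor holding the values [1, 56, 56, 32].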
1790
1791 if op.type == Op.Shape and op.run_on_npu:
1792
1793 ifm, ofm = op.get_ifm_ofm()
1794
1795 if len(ifm.shape) != ofm.shape[0]:
1796 return op
1797
1798 # Remove reference of the current shape op from the parent tensor's consumer list
1799 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1800
1801 # Clear any references to parent node
1802 op.inputs = []
1803
1804 # Convert this SHAPE op to const
1805 op.type = Op.Const
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01001806 DebugDatabase.add_optimised(op, op)
Ayaan Masood4965fae2022-06-29 11:30:57 +01001807
1808 # Add size calculation to shape output tensors
1809 ofm.values = np.array(ifm.shape)
1810
1811 return op
1812
1813
Tim Hallea4ba662022-11-11 18:19:53 +00001814def fixup_dilation_gt2(op, arch, nng):
1815 assert op.run_on_npu
1816 if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
1817 dilation_w, dilation_h = op.get_kernel_dilation()
1818
1819 # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
1820 # kernel
1821 if dilation_w > 2 or dilation_h > 2:
1822 kernel_w, kernel_h = op.get_kernel_size()
1823 kernel_ic = op.weights.shape[-2]
1824 kernel_oc = op.weights.shape[-1]
1825
1826 # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
1827 # of 2. This allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
1828 # odd = 1, even = 2
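# Illustrative case: dilation_w=4 gives hw_dilation_w=2 and scale_dilation_w=2, so a kernel of
# width 3 becomes a sparse kernel of width (3 - 1) * 2 + 1 = 5 and the remaining factor of 2 is
# provided by the hardware dilation.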
1829 hw_dilation_h = 1 if (dilation_h & 1) else 2
1830 hw_dilation_w = 1 if (dilation_w & 1) else 2
1831
1832 scale_dilation_h = dilation_h // hw_dilation_h
1833 scale_dilation_w = dilation_w // hw_dilation_w
1834
1835 # create new empty kernel (HWIO format)
1836 new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
1837 new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1
1838
1839 new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
1840 new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)
1841
1842 # copy the original kernel values into the new sparse kernel
1843 for h in range(0, kernel_h):
1844 for w in range(0, kernel_w):
1845 new_h = h * scale_dilation_h
1846 new_w = w * scale_dilation_w
1847 new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]
1848
1849 # update the weight tensor with the new dilated kernel
1850 op.weights.shape = new_kernel_shape
1851 op.weights.values = new_kernel_values
1852
1853 # enable(=2) / disable(=1) hardware dilation
1854 op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1) # nhwc format
1855 op.attrs["dilation_h_factor"] = hw_dilation_h
1856 op.attrs["dilation_w_factor"] = hw_dilation_w
1857
1858 return op
1859
1860
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001861def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001862 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001863 return op
1864
1865
1866def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001867 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001868 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1869
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001870 for idx, sg in enumerate(nng.subgraphs):
1871 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001872 nng,
1873 sg,
1874 arch,
1875 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001876 optimisation_list,
1877 rewrite_unsupported=False,
1878 )
1879
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001880 # Pre-processing step
1881 pre_process_list = [
1882 supported_operator_check,
1883 set_ifm_ofm_op_shapes,
1884 ]
1885
Ayaan Masood4965fae2022-06-29 11:30:57 +01001886 for idx, sg in enumerate(nng.subgraphs):
1887 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1888 nng,
1889 sg,
1890 arch,
1891 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001892 pre_process_list,
1893 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001894 )
1895
1896 # Handle Concat Ops
1897 for idx, sg in enumerate(nng.subgraphs):
1898 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1899 sg.refresh_after_modification()
1900
1901 # Handle Split Ops
1902 for idx, sg in enumerate(nng.subgraphs):
1903 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1904 nng,
1905 sg,
1906 arch,
1907 [],
1908 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1909 rewrite_unsupported=False,
1910 )
1911
1912 for idx, sg in enumerate(nng.subgraphs):
1913 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001914 nng,
1915 sg,
1916 arch,
1917 [rewrite_split_ops],
1918 [],
1919 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001920 )
1921
1922 # Handle sg input output
1923 for idx, sg in enumerate(nng.subgraphs):
1924 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001925 nng,
1926 sg,
1927 arch,
1928 [],
1929 [fix_sg_input_output],
1930 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001931 )
1932
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001933 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001934 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001935 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001936 sg.refresh_after_modification()
1937
1938 # Rewrite of operators
1939 op_rewrite_list = [
1940 set_tensor_equivalence,
1941 convert_mean_to_depthwise_conv_or_avgpool,
1942 convert_depthwise_to_conv,
1943 convert_conv_to_fc,
1944 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001945 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001946 convert_mul_max_to_abs_or_lrelu,
1947 convert_lrelu,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001948 optimise_strided_conv,
1949 convert_hardswish_to_lut,
1950 rewrite_fully_connected_input,
1951 convert_batched_fc_shape,
1952 fixup_conv2d_backprop,
1953 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001954 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001955 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001956 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001957 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001958 convert_tanh_sigmoid_to_lut,
1959 replace_pad_by_hw_pad,
Tim Hallea4ba662022-11-11 18:19:53 +00001960 fixup_dilation_gt2,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001961 ]
1962
1963 for idx, sg in enumerate(nng.subgraphs):
1964 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001965 nng,
1966 sg,
1967 arch,
1968 [],
1969 op_rewrite_list,
1970 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001971 )
1972
1973 for idx, sg in enumerate(nng.subgraphs):
1974 # remove passthrough tensors and attempt further optimizations
1975 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1976 nng,
1977 sg,
1978 arch,
1979 [remove_passthrough_tensor],
1980 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1981 )
1982
1983 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
1984 # since ifm/ofm_shapes are of importance to this function
1985 for sg in nng.subgraphs:
1986 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1987 sg.refresh_after_modification()
1988
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01001989 # Make sure that const optimisations on subgraph outputs are handled correctly
1990 for sg in nng.subgraphs:
1991 for ofm in sg.output_tensors:
1992 if ofm.is_const and ofm.ops[0].type_changed:
1993 # Subgraph output cannot be const - insert a memory copy
1994 op = ofm.ops[0]
1995 ofm_clone = ofm.clone()
1996 ofm_clone.values = ofm.values
1997 ofm.values = None
1998 np_dtype = ofm.dtype.as_numpy_type()
1999 zero = create_const_tensor("zero", [1], ofm.dtype, [0], np_dtype, quantization=ofm.quantization)
2000 memcpy = create_add_nop(f"{ofm.name}_copy")
2001 memcpy.add_input_tensor(ofm_clone)
2002 memcpy.add_input_tensor(zero)
2003 memcpy.set_output_tensor(ofm)
2004 memcpy.set_ifm_ofm_shapes()
2005 op.set_output_tensor(ofm_clone)
2006 DebugDatabase.add_optimised(op, memcpy)
2007
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002008 return nng