# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
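    # Illustrative example, assuming the usual SAME-padding definition used by needed_total_padding
    # (max((ceil(size / stride) - 1) * stride + kernel - size, 0)): width 224, stride 2 and a dilated
    # 3x3 kernel give xpad = 1, which the SAME branch below splits as left_pad = 0 and right_pad = 1,
    # i.e. the extra pixel of padding goes to the right/bottom.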
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to a depthwise convolution
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype == DataType.uint8:
        weight_value_dtype = np.uint8
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        if ofm_dtype == DataType.int8:
            weight_value_dtype = np.int8
        else:
            assert ofm_dtype == DataType.int16
            weight_value_dtype = np.int16

        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
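    # For example, with upscale_factor = 2 the kernel is 2x2, centre_coeff = (2 // 2) * 2 + (2 // 2) = 3
    # and weight_values = [0, 0, 0, 1], i.e. the single 1 selects position D in the diagram above.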

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm.dtype,
            np.array(weight_values).reshape(weight_shape),
            value_dtype=weight_value_dtype,
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))
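    # For example, upscale_factor = 8 gives n = 3: two x2 upscales in the loop below plus a final x2
    # upscale that is combined with the concluding average pool.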

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # Keep 1x1 kernel and average pool, this applies both when
            # half-pixel-centers is True and False. Calculations are the
            # same in the reference.
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()

    return op


def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
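                # Illustrative example: for a x2 upscale with half_pixel_centers=True the interpolation
                # fractions are 0.25 and 0.75, giving integer taps such as [[1, 3], [3, 9]]; the 1/16
                # weight scale applied in _build_convolutions restores the magnitude.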
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        value_dtype=np.int8,
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
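            # A batch of 4, 8 or 16 is folded into the H and W dimensions (e.g. batch 8 becomes a
            # 1x2x4xC shape); any other batch size n falls back to 1x1xnxC via the (1, n) default below.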
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
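            # Each iteration below clears the lowest set bit of the mask (mask &= mask - 1) and re-inserts
            # a size-1 dimension at that bit position. Illustrative example: a shrink_axis_mask of 0b101
            # re-inserts dimensions at axes 0 and 2.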
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(np.int32) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(np.int32) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.Add, op.name + "_add")
        add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |   \   /           This subgraph can be replaced with either
    |    Mul            an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
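        # For reference, the function being tabulated is hard_swish(x) = x * relu6(x + 3) / 6,
        # evaluated per quantized input value using the fixed-point arithmetic below.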
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in the TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescale the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    alpha = np.float32(op.attrs["alpha"])
    use_mul_max = 0 < alpha < 1
    is_converted_prelu = "alpha_scaling" in op.attrs
    if use_mul_max:
        mul_ifm = ifm
        new_op = Op.Maximum
    else:
        # Need to use a different approach for alpha < 0 or alpha > 1
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
        if alpha < 0 and not is_converted_prelu:
            # For negative alpha that is not from a converted PReLU we need to use
            # int32 Mul below to perform the (negative) alpha scaling
            mul_ifm.dtype = DataType.int32
        min_op.set_output_tensor(mul_ifm)
        min_op.set_ifm_ofm_shapes()
        new_op = Op.Add
        op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        DebugDatabase.add_optimised(op, min_op)

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(mul_ifm)
    # Create const tensor containing alpha as scalar
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    alpha_dtype = mul_ifm.dtype
    if is_converted_prelu:
        # The LeakyRelu was the result from convert_prelu and the scaling is provided
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
    elif alpha == 0 or np.isinf(1 / alpha):
        # Handling of alpha near or at zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        if alpha_dtype == DataType.int32:
            # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
            scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
        else:
            scalar = 1
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
    )
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if not use_mul_max:
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_id)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)
    elif check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max (or Add), add the results of the multiplication(s) as inputs
    op.type = new_op
    op.name = op.name.replace("LeakyRelu", new_op.name)
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)


1295def convert_lrelu_to_lut(op, arch):
1296 ifm, ofm = op.get_ifm_ofm()
1297 # Generate the LUT
1298 alpha = op.attrs["alpha"]
1299 ifm_scale = np.double(ifm.quantization.scale_f32)
1300 ofm_scale = np.double(ofm.quantization.scale_f32)
1301 zp_in = ifm.quantization.zero_point
1302 zp_out = ofm.quantization.zero_point
1303 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1304 alpha_scalar = 1
1305 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1306 if "alpha_scaling" in op.attrs:
1307 # The LeakyRelu is the result of convert_mul_max_to_abs_or_lrelu
1308 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
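# Each LUT entry below is produced with a fixed-point multiply-and-shift: input codes
# below the input zero point (negative real values) use the alpha scaling, all other
# codes use the identity scaling, and the result is offset by the output zero point
# and clamped to the quantized range.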
1309 values = []
1310 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1311 quantized_min = min(ix)
1312 quantized_max = max(ix)
1313 for x in ix:
1314 if x < zp_in:
1315 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1316 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1317 )
1318 else:
1319 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1320 lut_result = min(quantized_max, max(quantized_min, lut_result))
1321 values.append(lut_result)
1322 return convert_to_lut(op, values, "lrelu")
1323
1324
1325def convert_lrelu(op, arch, nng):
1326 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
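# Summary of the dispatch below: alpha == 0 becomes a plain Relu; int8/uint8 with
# matching ifm/ofm dtype uses a LUT; int16 with equal ifm/ofm scaling and positive
# alpha keeps LeakyRelu unmodified; everything else falls back to the mul + max
# decomposition.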
1327 if op.type != Op.LeakyRelu:
1328 return op
1329 ifm, ofm = op.get_ifm_ofm()
1330 if ifm is None or ofm is None:
1331 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001332 alpha = op.attrs["alpha"]
1333 if alpha == 0:
1334 # When alpha is 0 the operation can be converted to a ReLU
1335 op.type = Op.Relu
1336 op.name = op.name.replace("LeakyRelu", op.type.name)
1337 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001338 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1339 # use LUT for int8/uint8
1340 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001341 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001342 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001343 return op
1344 return convert_lrelu_to_mul_max(op, arch)
1345
1346
1347def convert_tanh_sigmoid_to_lut(op, arch, nng):
1348 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1349 if op.type == Op.Sigmoid:
1350 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1351 elif op.type == Op.Tanh:
1352 return convert_to_lut8(op, math.tanh, "tanh")
1353 return op
1354
1355
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001356def remove_memory_only_ops(op, arch):
1357 if op.run_on_npu and op.type in memory_only_ops:
1358 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001359
1360
1361def fuse_activation_function_with_prev(op, arch, nng):
1362 # if op is a no-op: attempts to move the activation function to the preceding op
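# Illustrative (assumed) example: a Conv2D whose only consumer is a no-op carrying a
# ReLU6 activation; the ReLU6, forced output quantization and any activation LUT are
# moved onto the Conv2D and the no-op is bypassed below.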
1363 if not op.attrs.get("is_nop", False) or op.activation is None:
1364 return op
1365 ifm, ofm = op.get_ifm_ofm()
1366 if ifm is None or ofm is None:
1367 return op
1368 # finds the input(s) to the operation
1369 prev_op = ifm.ops[0]
1370 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1371 fuse = (
1372 prev_op.run_on_npu
1373 and prev_op.type.npu_block_type != NpuBlockType.Default
1374 and len(ifm.ops) == 1
1375 and len(prev_op.outputs[0].consumers()) == 1
1376 and prev_op.activation is None
1377 )
1378 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1379 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1380 # LUT currently only works correctly for elementwise ops
1381 fuse = False
1382 if not fuse:
1383 return op
1384 # Move the fused activation function + corresponding info to prev_op
1385 prev_op.activation = op.activation
1386 prev_op.forced_output_quantization = op.forced_output_quantization
1387 if op.activation_lut is not None:
1388 prev_op.set_activation_lut(op.activation_lut)
1389 # Bypass op
1390 prev_op.set_output_tensor(ofm)
1391 DebugDatabase.add_optimised(op, prev_op)
1392 return op
1393
1394
1395def _leading_pad_ok(leading_pad, stride, kernel_size):
1396 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1397 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
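# Worked example: kernel_size = 7, stride = 2 gives max_size = 3, so a leading pad of
# 0 or 2 (multiples of the stride) or 3 (== max_size) is acceptable, while 1 is not.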
1398 max_size = kernel_size // 2
1399 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1400
1401
1402def replace_pad_by_hw_pad(op: Operation, arch, nng):
1403 """
1404 Tries to completely remove a PAD operator by using hardware padding.
1405 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1406 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1407 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1408 if both operations can be run on the NPU.
1409 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1410 """
1411 if (
1412 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001413 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001414 and op.run_on_npu
1415 and op.attrs["padding"] == Padding.VALID
1416 ):
1417 pad_op = op.ifm.ops[0]
1418 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1419 return op
1420 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1421 return op
1422 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1423 k = op.kernel
1424 k_w, k_h = k.dilated_wh()
1425
1426 # Check if the PAD operator can be replaced by hardware padding
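# Hardware padding can insert at most kernel_size // 2 rows/columns per edge,
# e.g. 1 for a 3x3 kernel or 2 for a 5x5 kernel (kernel sizes after dilation).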
1427 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1428 # Too much padding, it would require hardware padding to actually insert zeros
1429 return op
1430 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1431 return op
1432
1433 if op.type.is_avgpool_op():
1434 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1435 for pad, k_size in (
1436 (left, k_w),
1437 (right, k_w),
1438 (top, k_h),
1439 (bottom, k_h),
1440 ):
1441 if pad not in (0, k_size // 2):
1442 return op
1443 # Average pool is converted to depthwise, because NPU average pool + same padding
1444 # has a special implementation that is different from PAD followed by average pool with
1445 # valid padding.
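# The depthwise kernel created below has every weight set to 1 and a weight scale of
# 1 / (kernel width * kernel height), so each output equals the average over its window.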
1446 k_w, k_h = op.kernel.width, op.kernel.height
1447 ifm = op.ifm
1448 # Remember other inputs
1449 other_inputs = op.inputs[1:]
1450 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1451 quantization = QuantizationParameters(0.0, 255.0)
1452 quantization.scale_f32 = 1.0 / (k_w * k_h)
1453 quantization.zero_point = 0
1454 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1455 weights = np.full(shape, 1)
1456
1457 weight_tens = create_const_tensor(
1458 op.name + "_weights",
1459 shape,
1460 op.ifm.dtype,
1461 weights,
1462 np.uint8,
1463 purpose=TensorPurpose.Weights,
1464 quantization=quantization,
1465 )
James Peet7519d502021-07-19 16:47:58 +01001466 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001467 op.type = Op.DepthwiseConv2DBias
1468 op.inputs = []
1469 op.add_input_tensor(ifm)
1470 op.add_input_tensor(weight_tens)
1471 # Add bias tensor, all biases set to 0
1472 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001473 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001474 # Add other inputs
1475 op.inputs.extend(other_inputs)
1476 op.rounding_mode = NpuRoundingMode.NATURAL
1477
1478 # Bypass the PAD operator
1479 op.set_input_tensor(pad_op.ifm, 0)
1480 # Adjust the padding attributes of the convolution operator
1481 op.attrs["padding"] = Padding.EXPLICIT
1482 op.attrs["explicit_padding"] = (top, left, bottom, right)
1483 op.set_ifm_ofm_shapes()
1484 return op
1485
1486
1487def convert_pad(op: Operation, arch, nng):
1488 """
1489 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1490 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1491 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1492 """
1493 if op.type != Op.Pad or not op.run_on_npu:
1494 return op
1495 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1496
1497 ifm = op.ifm
1498 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001499 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001500 ofm = op.ofm
1501 assert ofm is not None
1502 ofm.ops = []
1503 ofm_shape = op.ofm_shapes[0]
1504
1505 # Average pool op that copies IFM to the right place inside the OFM
1506 shp0 = Shape4D(0, 0, 0, 0)
1507 shp_top = shp0.with_height(top)
1508 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1509 avgpool_op.activation = op.activation
1510 quant = ofm.quantization
1511 pad_value = quant.zero_point
1512 # Add operations that fill the borders of the OFM
1513 if top > 0:
1514 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1515 zero_tens = create_const_tensor(
1516 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1517 )
1518 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1519 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1520 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1521 if bottom > 0:
1522 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1523 zero_tens = create_const_tensor(
1524 op.name + "_bottom",
1525 shape.as_list(),
1526 ofm.dtype,
1527 shape.elements() * [pad_value],
1528 np.uint8,
1529 quantization=quant,
1530 )
1531 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1532 create_avg_pool_for_concat(
1533 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1534 )
1535 if left > 0:
1536 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1537 zero_tens = create_const_tensor(
1538 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1539 )
1540 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1541 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1542 if right > 0:
1543 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1544 zero_tens = create_const_tensor(
1545 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1546 )
1547 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1548 create_avg_pool_for_concat(
1549 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1550 )
1551
1552 op.type = Op.ConcatTFLite
1553 return avgpool_op
1554
1555
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001556def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001557 if op.type.needs_bias() and op.bias is None:
1558 # Op has no bias, add bias tensor filled with zeros
1559 nr_biases = op.inputs[1].shape[-1]
1560 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001561 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1562 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1563 # For int16 the selected bias DataType will have an impact on the scaling
1564 # used when encoding the scales and biases later. The default mode will match the
1565 # reference with reduced scaling for int64 bias.
1566 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1567 # is used to emulate average pool, an int32 bias should be selected for full precision
1568 # int16 scaling.
1569 if dtype is None:
1570 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1571 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001572 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1573
1574 return op
1575
1576
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001577def fixup_asymmetric_weights(op, arch, nng):
1578 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
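# Per the TFLite quantization spec, int8 weights are expected to be symmetric
# (zero point 0); a non-zero weight zero point is therefore treated as a model
# error and forced to zero, with a warning.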
1579 if op.ifm.dtype == DataType.int8:
1580 if not np.all(op.weights.quantization.zero_point == 0):
1581 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1582 op.weights.quantization.zero_point *= 0
1583
1584 return op
1585
1586
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001587def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1588 if op.type == Op.Mean and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001589 inp, axis = op.inputs
1590 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001591 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001592 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001593 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001594
1595 # Height and width axes have different indices depending on the number of dimensions
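# e.g. a [1, H, W, C] input: axis 1 gives (h, w) = (H, 1) and axis 2 gives (1, W);
# a [H, W, C] input: axis 0 gives (H, 1) and axis 1 gives (1, W);
# multiple axes such as [1, 2] on [1, H, W, C] give (H, W).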
1596 if axis.shape == [] or axis.shape[0] == 1: # single axis
1597 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1598 if dims in (2, 3):
1599 if axis == 0:
1600 h, w = shape[axis], 1
1601 else:
1602 h, w = 1, shape[axis]
1603 else:
1604 if axis == 1:
1605 h, w = shape[axis], 1
1606 else:
1607 h, w = 1, shape[axis]
1608 else: # multiple axes
1609 axis = sorted(axis.values)
1610 h, w = [shape[i] for i in axis]
1611
1612 # Set necessary depthwise attributes
1613 op.attrs.update(
1614 {
1615 "padding": Padding.VALID,
1616 "stride_h": 1,
1617 "stride_w": 1,
1618 "strides": (1, 1, 1, 1),
1619 "depth_multiplier": 1,
1620 "channel_multiplier": 1,
1621 "dilation_h_factor": 1,
1622 "dilation_w_factor": 1,
1623 "dilation": (1, 1, 1, 1),
1624 }
1625 )
1626 # Change op type
1627 op.type = Op.DepthwiseConv2DBias
1628 # Set IFM/OFM shapes after changing op type
1629 op.set_ifm_ofm_shapes()
1630
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001631 weight_scale, bias = 1, 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001632 ofmq, ifmq = op.ofm.quantization, inp.quantization
Johan Alfvén9d51ec42022-10-27 16:30:01 +02001633 if ifmq.is_scaling_equal(ofmq):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001634 # Here we can just use a simple AvgPool with truncating rounding,
1635 # as we're emulating simple integer division.
1636 op.rounding_mode = NpuRoundingMode.TRUNCATE
1637 op.type = Op.AvgPool
1638 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1639 else:
1640 op.rounding_mode = NpuRoundingMode.NATURAL
1641 weight_scale = 1 / (h * w)
1642 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1643 bias = -ifmq.zero_point * h * w
1644 fiq = ifmq.clone()
1645 fiq.zero_point = 0
1646 op.forced_input_quantization = fiq
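# In effect: mean = (sum(x) - h * w * zp_in) / (h * w), realised as a depthwise conv
# with unit weights, weight scale 1 / (h * w) and bias -zp_in * h * w, while the
# forced zero point of 0 stops the input offset from being removed twice.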
1647
1648 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001649 def extend_dims(dim, in_shape):
1650 if dim < 4:
1651 in_shape = [1] + in_shape
1652 if dim == 2:
1653 in_shape += [1]
1654 return in_shape
1655
1656 if dims < 4 or dims_ofm < 4:
1657 # Fix the ofm dimension when keep_dims is false
1658 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1659 if isinstance(axis, int) and dims_ofm + 1 == dims:
1660 ofm_shape.insert(axis, 1)
1661 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1662 for i in axis:
1663 ofm_shape.insert(i, 1)
1664 shape = extend_dims(dims, shape)
1665 dims_ofm = len(ofm_shape)
1666 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1667 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001668
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001669 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001670 weight_shape = None
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001671 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001672 # This can only happen, and only be done, when reducing over multiple axes, and requires
1673 # h * w <= 256 for DepthwiseConv2DBias
1674 # h * w <= 4096 for AvgPool
1675 # which is checked by the supported operators check
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001676 shape = [shape[0], 1, h * w, shape[3]]
1677 op.ifm_shapes[0] = Shape4D(shape)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001678 weight_shape = [1, h * w, shape[3], shape[0]]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001679 if h > 256 and op.type == Op.AvgPool:
1680 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1681
1682 # If the AvgPool version is used, we don't need to do anything else
1683 if op.type == Op.AvgPool:
1684 return op
1685
1686 # Make unit weight tensor quantization
1687 weight_quant = ifmq.clone()
1688 weight_quant.min = 0
1689 weight_quant.max = 255
1690 weight_quant.scale_f32 = weight_scale
1691 weight_quant.zero_point = 0
1692
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001693 if weight_shape is None:
1694 # Set weight shape to [H,W,C,B]
1695 weight_shape = [h, w, shape[3], shape[0]]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001696
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001697 # Add unit weight tensor
1698 op.set_input_tensor(
1699 create_const_tensor(
1700 "weights",
1701 weight_shape,
1702 inp.dtype,
1703 np.ones(weight_shape),
1704 value_dtype=np.uint8,
1705 quantization=weight_quant,
1706 ),
1707 1,
1708 )
James Peet7519d502021-07-19 16:47:58 +01001709 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001710
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001711 # Add bias tensor
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001712 bias_shape = [shape[-1]]
1713 op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001714
1715 return op
1716
1717
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001718def optimise_quantize(op: Operation, arch, nng):
1719
1720 if op.type == Op.Quantize and op.run_on_npu:
1721
1722 ifm, ofm = op.get_ifm_ofm()
1723 input_values = ifm.values
1724
1725 # Guard clause - input not const or no values to quantize
1726 if ifm.ops[0].type != Op.Const or input_values is None:
1727 return op
1728
1729 # Scalar value in numpy array, convert to an indexable array
1730 if input_values.ndim == 0:
1731 input_values = np.array([input_values])
1732
Fredrik Svedberg11563172022-07-06 14:54:12 +02001733 # Case: requantize int8 to int8 or int16 to int16
1734 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001735
1736 # scale needs to use double precision to match TFLite reference kernel
1737 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1738 effective_multiplier, effective_shift = quantise_scale(effective_scale)
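# quantise_scale expresses effective_scale as a fixed-point multiplier and shift;
# each value below is then requantized as
# zp_out + multiply_by_quantized_multiplier(v - zp_in, multiplier, shift),
# clamped to the output quantization range.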
1739
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001740 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001741 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001742 input_val = val - ifm.quantization.zero_point
1743
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001744 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1745 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001746
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001747 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1748 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001749
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001750 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1751 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001752
1753 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001754 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001755
1756 quantized_vals = []
1757 for val in input_values:
1758
1759 # Derive quantized value
1760 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001761 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1762 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001763
1764 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001765 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1766
1767 # Unsupported data type
1768 else:
1769 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001770
1771 # Make quantize op const and disconnect from parent node
1772
1773 # Remove reference of the current quant op from the parent tensor's consumer list
1774 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1775
1776 # Clear any references to parent node
1777 op.inputs = []
1778
1779 # Convert this quantize op to const
1780 op.type = Op.Const
1781
1782 return op
1783
1784
Ayaan Masood4965fae2022-06-29 11:30:57 +01001785def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1786 """Static optimisation for SHAPE operator output value known at compile time"""
1787
1788 # Disconnect the SHAPE operator from its parent and transform the SHAPE op into a constant
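# e.g. (illustrative) a SHAPE op reading a [1, 224, 224, 3] tensor becomes a Const op
# whose output values are [1, 224, 224, 3].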
1789
1790 if op.type == Op.Shape and op.run_on_npu:
1791
1792 ifm, ofm = op.get_ifm_ofm()
1793
1794 if len(ifm.shape) != ofm.shape[0]:
1795 return op
1796
1797 # Remove reference of the current shape op from the parent tensor's consumer list
1798 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1799
1800 # Clear any references to parent node
1801 op.inputs = []
1802
1803 # Convert this SHAPE op to const
1804 op.type = Op.Const
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01001805 DebugDatabase.add_optimised(op, op)
Ayaan Masood4965fae2022-06-29 11:30:57 +01001806
1807 # Write the statically known input shape into the SHAPE output tensor
1808 ofm.values = np.array(ifm.shape)
1809
1810 return op
1811
1812
Tim Hallea4ba662022-11-11 18:19:53 +00001813def fixup_dilation_gt2(op, arch, nng):
1814 assert op.run_on_npu
1815 if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
1816 dilation_w, dilation_h = op.get_kernel_dilation()
1817
1818 # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
1819 # kernel
1820 if dilation_w > 2 or dilation_h > 2:
1821 kernel_w, kernel_h = op.get_kernel_size()
1822 kernel_ic = op.weights.shape[-2]
1823 kernel_oc = op.weights.shape[-1]
1824
1825 # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
1826 # of 2. This allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
1827 # odd = 1, even = 2
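# e.g. dilation 4 splits into hardware dilation 2 and kernel scaling 2, so a 3x3
# kernel becomes a sparse 5x5 kernel ((3 - 1) * 2 + 1) with the original weights
# placed at every second position.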
1828 hw_dilation_h = 1 if (dilation_h & 1) else 2
1829 hw_dilation_w = 1 if (dilation_w & 1) else 2
1830
1831 scale_dilation_h = dilation_h // hw_dilation_h
1832 scale_dilation_w = dilation_w // hw_dilation_w
1833
1834 # create new empty kernel (HWIO format)
1835 new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
1836 new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1
1837
1838 new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
1839 new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)
1840
1841 # copy the original kernel values into the new sparse kernel
1842 for h in range(0, kernel_h):
1843 for w in range(0, kernel_w):
1844 new_h = h * scale_dilation_h
1845 new_w = w * scale_dilation_w
1846 new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]
1847
1848 # update the weight tensor with the new dilated kernel
1849 op.weights.shape = new_kernel_shape
1850 op.weights.values = new_kernel_values
1851
1852 # enable(=2) / disable(=1) hardware dilation
1853 op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1) # nhwc format
1854 op.attrs["dilation_h_factor"] = hw_dilation_h
1855 op.attrs["dilation_w_factor"] = hw_dilation_w
1856
1857 return op
1858
1859
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001860def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001861 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001862 return op
1863
1864
1865def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001866 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001867 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1868
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001869 for idx, sg in enumerate(nng.subgraphs):
1870 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001871 nng,
1872 sg,
1873 arch,
1874 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001875 optimisation_list,
1876 rewrite_unsupported=False,
1877 )
1878
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001879 # Pre-processing step
1880 pre_process_list = [
1881 supported_operator_check,
1882 set_ifm_ofm_op_shapes,
1883 ]
1884
Ayaan Masood4965fae2022-06-29 11:30:57 +01001885 for idx, sg in enumerate(nng.subgraphs):
1886 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1887 nng,
1888 sg,
1889 arch,
1890 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001891 pre_process_list,
1892 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001893 )
1894
1895 # Handle Concat Ops
1896 for idx, sg in enumerate(nng.subgraphs):
1897 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1898 sg.refresh_after_modification()
1899
1900 # Handle Split Ops
1901 for idx, sg in enumerate(nng.subgraphs):
1902 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1903 nng,
1904 sg,
1905 arch,
1906 [],
1907 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1908 rewrite_unsupported=False,
1909 )
1910
1911 for idx, sg in enumerate(nng.subgraphs):
1912 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001913 nng,
1914 sg,
1915 arch,
1916 [rewrite_split_ops],
1917 [],
1918 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001919 )
1920
1921 # Handle sg input output
1922 for idx, sg in enumerate(nng.subgraphs):
1923 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001924 nng,
1925 sg,
1926 arch,
1927 [],
1928 [fix_sg_input_output],
1929 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001930 )
1931
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001932 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001933 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001934 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001935 sg.refresh_after_modification()
1936
1937 # Rewrite of operators
1938 op_rewrite_list = [
1939 set_tensor_equivalence,
1940 convert_mean_to_depthwise_conv_or_avgpool,
1941 convert_depthwise_to_conv,
1942 convert_conv_to_fc,
1943 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001944 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001945 convert_mul_max_to_abs_or_lrelu,
1946 convert_lrelu,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001947 optimise_strided_conv,
1948 convert_hardswish_to_lut,
1949 rewrite_fully_connected_input,
1950 convert_batched_fc_shape,
1951 fixup_conv2d_backprop,
1952 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001953 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001954 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001955 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001956 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001957 convert_tanh_sigmoid_to_lut,
1958 replace_pad_by_hw_pad,
Tim Hallea4ba662022-11-11 18:19:53 +00001959 fixup_dilation_gt2,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001960 ]
1961
1962 for idx, sg in enumerate(nng.subgraphs):
1963 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001964 nng,
1965 sg,
1966 arch,
1967 [],
1968 op_rewrite_list,
1969 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001970 )
1971
1972 for idx, sg in enumerate(nng.subgraphs):
1973 # remove passthrough tensors and attempt further optimizations
1974 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1975 nng,
1976 sg,
1977 arch,
1978 [remove_passthrough_tensor],
1979 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1980 )
1981
1982 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
1983 # since ifm/ofm_shapes are of importance to this function
1984 for sg in nng.subgraphs:
1985 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1986 sg.refresh_after_modification()
1987
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01001988 # Make sure that const optimisations on subgraph outputs are handled correctly
1989 for sg in nng.subgraphs:
1990 for ofm in sg.output_tensors:
1991 if ofm.is_const and ofm.ops[0].type_changed:
1992 # Subgraph output cannot be const - insert a memory copy
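# The memory copy is an elementwise Add with a zero constant: at runtime it reads
# the cloned const tensor and writes the original, now non-const, subgraph output
# tensor.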
1993 op = ofm.ops[0]
1994 ofm_clone = ofm.clone()
1995 ofm_clone.values = ofm.values
1996 ofm.values = None
1997 np_dtype = ofm.dtype.as_numpy_type()
1998 zero = create_const_tensor("zero", [1], ofm.dtype, [0], np_dtype, quantization=ofm.quantization)
1999 memcpy = create_add_nop(f"{ofm.name}_copy")
2000 memcpy.add_input_tensor(ofm_clone)
2001 memcpy.add_input_tensor(zero)
2002 memcpy.set_output_tensor(ofm)
2003 memcpy.set_ifm_ofm_shapes()
2004 op.set_output_tensor(ofm_clone)
2005 DebugDatabase.add_optimised(op, memcpy)
2006
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002007 return nng