# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    return avgpool_op


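# Replace a tensor produced by a passthrough node (e.g. Identity) with the producing op's input tensor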
def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


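# Rewrite concatenation ops (Concat/Pack) into average pool ops that each write one input feature map
# to the correct offset within the shared output tensor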
def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


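# Rewrite split/slice ops by replacing the producer of the output tensor with a SplitSliceRead op
# that reads the appropriate offset and shape from the input tensor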
def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


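# Move a SplitSliceRead op into its consumer when possible, otherwise materialise it as an average pool
# that performs the read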
def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


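# Calculate explicit padding (top, left, bottom, right) and skirt for the given padding type, kernel and input shape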
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


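# Calculate padding and skirt for ops where the IFM is upscaled, e.g. transpose convolution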
def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


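# Convert Conv2DBackpropInput to Conv2DBackpropInputSwitchedBias with unit strides and transpose IFM resampling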
def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
        DebugDatabase.add_optimised(op, op)

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype == DataType.uint8:
        weight_value_dtype = np.uint8
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        if ofm_dtype == DataType.int8:
            weight_value_dtype = np.int8
        else:
            assert ofm_dtype == DataType.int16
            weight_value_dtype = np.int16

        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm.dtype,
            np.array(weight_values).reshape(weight_shape),
            value_dtype=weight_value_dtype,
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # Keep 1x1 kernel and average pool, this applies both when
            # half-pixel-centers is True and False. Calculations are the
            # same in the reference.
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, scaled_op)

    return op


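# Convert a x2 upscaling ResizeBilinear to an initial average pool followed by four 2x2 depthwise convolutions
# whose outputs are interleaved into the OFM using tile base offsets and doubled OFM strides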
def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        value_dtype=np.int8,
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, dw_conv)

                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


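# Lower a resize op: bypass it when it is a NOP, otherwise convert it to an elementwise add,
# depthwise convolutions or upscale + average pool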
def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


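# Reshape the FullyConnected IFM to 2D; if the IFM is batching, make sure the OFM shape is batching too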
def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op


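# Spread the batch dimension of a batched FullyConnected over the height and width dimensions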
def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, act_op)


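# Adjust StridedSlice output shapes for new_axis_mask/shrink_axis_mask and record the resulting 4D split axis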
def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


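# Calculate and attach explicit padding and skirt attributes to ops that use padding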
def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


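# Optimise a stride-2 Conv2D with a shallow IFM by reshaping the IFM to half width and double depth and
# padding/reshaping the weights to match, so that the convolution can be run with stride 1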
def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


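# Give a Relu with differing IFM and OFM scaling its own primary op (average pool) with explicit rescaling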
def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own primary op to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(np.int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(np.int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()
                    DebugDatabase.add_optimised(op, mul_identity)

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.Add, op.name + "_add")
        add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |   \   /           This subgraph can be replaced with either
    |    Mul             an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


1171def convert_lrelu_to_mul_max(op, arch):
1172 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
1173 # (the opposite of convert_mul_max_to_abs_or_lrelu)
1174 ifm, ofm = op.get_ifm_ofm()
1175 if ifm is None or ofm is None:
1176 return op
1177
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001178 alpha = np.float32(op.attrs["alpha"])
1179 use_mul_max = 0 < alpha < 1
Fredrik Svedberg36424312022-09-16 09:39:26 +02001180 is_converted_prelu = "alpha_scaling" in op.attrs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001181 if use_mul_max:
1182 mul_ifm = ifm
1183 new_op = Op.Maximum
1184 else:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001185 # Need to use a different approach for alpha < 0 or alpha > 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001186 no_scale_quant = ifm.quantization.clone()
1187 no_scale_quant.scale_f32 = None
1188 no_scale_quant.zero_point = 0
1189 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
1190
1191 # Select values < 0
1192 min_op = Operation(Op.Minimum, op.name + "_min")
1193 min_op.add_input_tensor(ifm)
1194 min_op.add_input_tensor(zero)
1195 mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001196 if alpha < 0 and not is_converted_prelu:
1197 # For negative alpha that is not from a converted PReLU we need to use
1198 # int32 Mul below to perform the (negative) alpha scaling
1199 mul_ifm.dtype = DataType.int32
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001200 min_op.set_output_tensor(mul_ifm)
1201 min_op.set_ifm_ofm_shapes()
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001202 new_op = Op.Add
1203 op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001204 DebugDatabase.add_optimised(op, min_op)
1205
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001206 # Add multiplication with alpha
1207 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001208 mul_alpha.add_input_tensor(mul_ifm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001209 # Create const tensor containing alpha as scalar
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001210 quantization = ifm.quantization.clone()
1211 quantization.min = 0
1212 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
1213 quantization.zero_point = 0
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001214 alpha_dtype = mul_ifm.dtype
Fredrik Svedberg36424312022-09-16 09:39:26 +02001215 if is_converted_prelu:
1216 # The LeakyRelu was the result from convert_prelu and the scaling is provided
Fredrik Svedberg66591652022-08-29 10:51:27 +02001217 scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001218 mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001219 elif alpha == 0 or np.isinf(1 / alpha):
1220 # Handling of alpha near or at zero
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001221 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001222 scalar = 0
1223 else:
1224 quantization.scale_f32 = alpha
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001225 if alpha_dtype == DataType.int32:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001226 # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001227 scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
1228 else:
1229 scalar = 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001230 alpha_tens = create_const_tensor(
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001231 op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001232 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001233 mul_alpha.add_input_tensor(alpha_tens)
1234 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1235 mul_alpha.set_output_tensor(fm_alpha)
1236 mul_alpha.set_ifm_ofm_shapes()
1237 DebugDatabase.add_optimised(op, mul_alpha)
1238
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001239 if not use_mul_max:
1240 relu_op = Operation(Op.Relu, op.name + "_relu")
1241 relu_op.add_input_tensor(ifm)
1242 fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1243 relu_op.set_output_tensor(fm_id)
1244 relu_op.set_ifm_ofm_shapes()
1245 DebugDatabase.add_optimised(op, relu_op)
1246 elif check_quantized_tens_scaling_equal(ifm, ofm):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001247 # No identity multiplication is needed
1248 fm_id = ifm
1249 else:
1250 # Add multiplication with identity
1251 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1252 mul_identity.add_input_tensor(ifm)
1253 # Create const tensor containing identity as scalar
1254 quantization = ifm.quantization.clone()
1255 quantization.min = 0
1256 quantization.max = quantization.quant_max - quantization.quant_min
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001257 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001258 quantization.zero_point = 0
1259 identity_tens = create_const_tensor(
1260 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
1261 )
1262 mul_identity.add_input_tensor(identity_tens)
1263 # Make sure that fm_id is allocated to a different address than fm_alpha
1264 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1265 mul_identity.set_output_tensor(fm_id)
1266 mul_identity.set_ifm_ofm_shapes()
1267 DebugDatabase.add_optimised(op, mul_identity)
1268
1269 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001270 op.type = new_op
1271 op.name = op.name.replace("LeakyRelu", new_op.name)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001272 op.inputs = []
1273 ifm.consumer_list.remove(op)
1274 op.add_input_tensor(fm_alpha)
1275 op.add_input_tensor(fm_id)
1276 op.set_ifm_ofm_shapes()
1277
1278 DebugDatabase.add_optimised(op, op)
1279 return op
1280
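# Illustrative sketch (not used by the pass above): the Mul + Maximum rewrite relies on the
# identity LeakyRelu(x) = max(alpha * x, x), which holds for 0 < alpha < 1. The alpha value
# below is a made-up example, not taken from any real graph.
def _example_lrelu_as_mul_max(x, alpha=0.1):
    # Float reference of the rewrite: one multiplication by alpha, one (identity) copy of x,
    # then an elementwise maximum; the NPU performs the same steps in the quantized domain.
    return np.maximum(alpha * x, x)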
1281
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001282def convert_to_lut8(op, fn, fn_name):
1283 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1284 # fn is a function(real) -> real
1285 ifm, ofm = op.get_ifm_ofm()
1286 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1287 return op
1288 # Generate the LUT
1289 ifm_scale = np.double(ifm.quantization.scale_f32)
1290 ofm_scale = np.double(ofm.quantization.scale_f32)
1291 zp_in = ifm.quantization.zero_point
1292 zp_out = ofm.quantization.zero_point
1293 values = []
1294 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1295 quantized_min = min(ix)
1296 quantized_max = max(ix)
1297 for x in ix:
1298 x_real = ifm_scale * (x - zp_in)
1299 y_real = fn(x_real)
1300 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1301 lut_result = min(quantized_max, max(quantized_min, lut_result))
1302 values.append(lut_result)
1303 return convert_to_lut(op, values, fn_name)
1304
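# Minimal sketch of the table generation in convert_to_lut8, outside the Operation/Tensor
# machinery. The scales and zero points are made-up example values for an int8 tensor;
# fn is any real -> real function, e.g. _example_generate_int8_lut(clamp_sigmoid).
def _example_generate_int8_lut(fn, ifm_scale=0.05, ofm_scale=1.0 / 256, zp_in=0, zp_out=-128):
    values = []
    for x in range(-128, 128):
        y_real = fn(ifm_scale * (x - zp_in))  # dequantize the input index, apply the function
        lut_val = round_away_zero(zp_out + y_real / ofm_scale)  # requantize to the output scale
        values.append(min(127, max(-128, lut_val)))  # clamp to the int8 range
    return values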
1305
1306def convert_lrelu_to_lut(op, arch):
1307 ifm, ofm = op.get_ifm_ofm()
1308 # Generate the LUT
1309 alpha = op.attrs["alpha"]
1310 ifm_scale = np.double(ifm.quantization.scale_f32)
1311 ofm_scale = np.double(ofm.quantization.scale_f32)
1312 zp_in = ifm.quantization.zero_point
1313 zp_out = ofm.quantization.zero_point
1314 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1315 alpha_scalar = 1
1316 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1317 if "alpha_scaling" in op.attrs:
1318 # The LeakyRelu is the result of convert_mul_max_to_abs_or_lrelu
1319 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1320 values = []
1321 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1322 quantized_min = min(ix)
1323 quantized_max = max(ix)
1324 for x in ix:
1325 if x < zp_in:
1326 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1327 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1328 )
1329 else:
1330 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1331 lut_result = min(quantized_max, max(quantized_min, lut_result))
1332 values.append(lut_result)
1333 return convert_to_lut(op, values, "lrelu")
1334
1335
1336def convert_lrelu(op, arch, nng):
1337 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1338 if op.type != Op.LeakyRelu:
1339 return op
1340 ifm, ofm = op.get_ifm_ofm()
1341 if ifm is None or ofm is None:
1342 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001343 alpha = op.attrs["alpha"]
1344 if alpha == 0:
1345 # When alpha is 0 the operation can be converted to a ReLU
1346 op.type = Op.Relu
1347 op.name = op.name.replace("LeakyRelu", op.type.name)
1348 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001349 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1350 # use LUT for int8/uint8
1351 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001352 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001353 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001354 return op
1355 return convert_lrelu_to_mul_max(op, arch)
1356
1357
1358def convert_tanh_sigmoid_to_lut(op, arch, nng):
1359 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1360 if op.type == Op.Sigmoid:
1361 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1362 elif op.type == Op.Tanh:
1363 return convert_to_lut8(op, math.tanh, "tanh")
1364 return op
1365
1366
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001367def remove_memory_only_ops(op, arch):
1368 if op.run_on_npu and op.type in memory_only_ops:
1369 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001370
1371
1372def fuse_activation_function_with_prev(op, arch, nng):
1373 # if op is a no-op: attempts to move the activation function to the preceding op
1374 if not op.attrs.get("is_nop", False) or op.activation is None:
1375 return op
1376 ifm, ofm = op.get_ifm_ofm()
1377 if ifm is None or ofm is None:
1378 return op
1379 # finds the input(s) to the operation
1380 prev_op = ifm.ops[0]
1381 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1382 fuse = (
1383 prev_op.run_on_npu
1384 and prev_op.type.npu_block_type != NpuBlockType.Default
1385 and len(ifm.ops) == 1
1386 and len(prev_op.outputs[0].consumers()) == 1
1387 and prev_op.activation is None
1388 )
1389 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1390 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1391 # LUT currently only works correctly for elementwise ops
1392 fuse = False
1393 if not fuse:
1394 return op
1395 # Move the fused activation function + corresponding info to prev_op
1396 prev_op.activation = op.activation
1397 prev_op.forced_output_quantization = op.forced_output_quantization
1398 if op.activation_lut is not None:
1399 prev_op.set_activation_lut(op.activation_lut)
1400 # Bypass op
1401 prev_op.set_output_tensor(ofm)
wilisa0179a89042022-11-02 17:18:43 +00001402 DebugDatabase.add_optimised(prev_op, prev_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001403 return op
1404
1405
1406def _leading_pad_ok(leading_pad, stride, kernel_size):
1407 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1408 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
1409 max_size = kernel_size // 2
1410 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1411
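# Worked example of the rule above, with made-up values: for kernel_size=7, stride=2 the kernel
# reach is 7 // 2 = 3, which is greater than the stride, so a leading pad must be 3 (the maximum)
# or a multiple of the stride; a pad of 1 would start the kernel on the wrong IFM row/column.
def _example_leading_pads(stride=2, kernel_size=7):
    # Returns the acceptable leading pads, [0, 2, 3] for the defaults above
    return [pad for pad in range(kernel_size // 2 + 1) if _leading_pad_ok(pad, stride, kernel_size)]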
1412
1413def replace_pad_by_hw_pad(op: Operation, arch, nng):
1414 """
1415 Tries to completely remove a PAD operator by using hardware padding.
1416 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1417 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1418 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1419 if both operations can be run on the NPU.
1420 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1421 """
1422 if (
1423 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001424 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001425 and op.run_on_npu
1426 and op.attrs["padding"] == Padding.VALID
1427 ):
1428 pad_op = op.ifm.ops[0]
1429 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1430 return op
1431 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1432 return op
1433 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1434 k = op.kernel
1435 k_w, k_h = k.dilated_wh()
1436
1437 # Check if the PAD operator can be replaced by hardware padding
1438 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1439 # Too much padding, it would require hardware padding to actually insert zeros
1440 return op
1441 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1442 return op
1443
1444 if op.type.is_avgpool_op():
1445 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1446 for pad, k_size in (
1447 (left, k_w),
1448 (right, k_w),
1449 (top, k_h),
1450 (bottom, k_h),
1451 ):
1452 if pad not in (0, k_size // 2):
1453 return op
1454 # Average pool is converted to depthwise, because NPU average pool + same padding
1455 # has a special implementation that is different from PAD followed by average pool with
1456 # valid padding.
1457 k_w, k_h = op.kernel.width, op.kernel.height
1458 ifm = op.ifm
1459 # Remember other inputs
1460 other_inputs = op.inputs[1:]
1461 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1462 quantization = QuantizationParameters(0.0, 255.0)
1463 quantization.scale_f32 = 1.0 / (k_w * k_h)
1464 quantization.zero_point = 0
1465 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1466 weights = np.full(shape, 1)
1467
1468 weight_tens = create_const_tensor(
1469 op.name + "_weights",
1470 shape,
1471 op.ifm.dtype,
1472 weights,
1473 np.uint8,
1474 purpose=TensorPurpose.Weights,
1475 quantization=quantization,
1476 )
James Peet7519d502021-07-19 16:47:58 +01001477 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001478 op.type = Op.DepthwiseConv2DBias
1479 op.inputs = []
1480 op.add_input_tensor(ifm)
1481 op.add_input_tensor(weight_tens)
1482 # Add bias tensor, all biases set to 0
1483 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001484 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001485 # Add other inputs
1486 op.inputs.extend(other_inputs)
1487 op.rounding_mode = NpuRoundingMode.NATURAL
1488
1489 # Bypass the PAD operator
1490 op.set_input_tensor(pad_op.ifm, 0)
1491 # Adjust the padding attributes of the convolution operator
1492 op.attrs["padding"] = Padding.EXPLICIT
1493 op.attrs["explicit_padding"] = (top, left, bottom, right)
1494 op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +00001495 DebugDatabase.add_optimised(op, op)
1496
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001497 return op
1498
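# Simplified sketch of the size checks performed by replace_pad_by_hw_pad, using hypothetical
# padding and kernel values; it omits the dtype/scaling checks and the extra average-pool
# restriction. A one-pixel PAD in front of a 3x3 VALID convolution passes the checks and can
# therefore be absorbed as hardware (EXPLICIT) padding.
def _example_pad_absorbable(top=1, left=1, bottom=1, right=1, k_w=3, k_h=3, stride=1):
    within_reach = left <= k_w // 2 and right <= k_w // 2 and top <= k_h // 2 and bottom <= k_h // 2
    return within_reach and _leading_pad_ok(top, stride, k_h) and _leading_pad_ok(left, stride, k_w)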
1499
1500def convert_pad(op: Operation, arch, nng):
1501 """
1502 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1503 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1504 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1505 """
1506 if op.type != Op.Pad or not op.run_on_npu:
1507 return op
1508 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1509
1510 ifm = op.ifm
1511 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001512 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001513 ofm = op.ofm
1514 assert ofm is not None
1515 ofm.ops = []
1516 ofm_shape = op.ofm_shapes[0]
1517
1518 # Average pool op that copies IFM to the right place inside the OFM
1519 shp0 = Shape4D(0, 0, 0, 0)
1520 shp_top = shp0.with_height(top)
1521 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1522 avgpool_op.activation = op.activation
1523 quant = ofm.quantization
1524 pad_value = quant.zero_point
1525 # Add operations that fill the borders of the OFM
1526 if top > 0:
1527 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1528 zero_tens = create_const_tensor(
1529 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1530 )
1531 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1532 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1533 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1534 if bottom > 0:
1535 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1536 zero_tens = create_const_tensor(
1537 op.name + "_bottom",
1538 shape.as_list(),
1539 ofm.dtype,
1540 shape.elements() * [pad_value],
1541 np.uint8,
1542 quantization=quant,
1543 )
1544 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1545 create_avg_pool_for_concat(
1546 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1547 )
1548 if left > 0:
1549 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1550 zero_tens = create_const_tensor(
1551 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1552 )
1553 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1554 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1555 if right > 0:
1556 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1557 zero_tens = create_const_tensor(
1558 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1559 )
1560 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1561 create_avg_pool_for_concat(
1562 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1563 )
1564
1565 op.type = Op.ConcatTFLite
1566 return avgpool_op
1567
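# Plain-numpy picture of what the fall-back above produces, with example NHWC shapes: the OFM is
# filled with the pad value (the output zero point, i.e. real zero) and the IFM is copied into the
# interior at offset (top, left). The real pass builds this from one copying average pool plus up
# to four border-filling average pools rather than a single array operation.
def _example_pad_fallback(ifm_nhwc, top, bottom, left, right, pad_value=0):
    n, h, w, c = ifm_nhwc.shape
    ofm = np.full((n, h + top + bottom, w + left + right, c), pad_value, dtype=ifm_nhwc.dtype)
    ofm[:, top : top + h, left : left + w, :] = ifm_nhwc
    return ofm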
1568
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001569def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001570 if op.type.needs_bias() and op.bias is None:
1571 # Op has no bias, add bias tensor filled with zeros
1572 nr_biases = op.inputs[1].shape[-1]
1573 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001574 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1575 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1576 # For int16 the selected bias DataType will have an impact on the scaling
1577 # used when encoding the scales and biases later. The default mode will match the
1578 # reference with reduced scaling for int64 bias.
1579 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1580 # is used to emulate average pool, int32 bias should be selected for full precision
1581 # int16 scaling.
1582 if dtype is None:
1583 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1584 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001585 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1586
1587 return op
1588
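# Small sketch of the default bias type rule described in fixup_bias_tensors (example only,
# no graph objects involved): int16 feature maps get an int64 zero bias, everything else int32,
# unless the caller passes an explicit dtype.
def _example_default_bias_dtype(ifm_dtype):
    return DataType.int64 if ifm_dtype == DataType.int16 else DataType.int32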
1589
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001590def fixup_asymmetric_weights(op, arch, nng):
1591 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1592 if op.ifm.dtype == DataType.int8:
1593 if not np.all(op.weights.quantization.zero_point == 0):
1594 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1595 op.weights.quantization.zero_point *= 0
1596
1597 return op
1598
1599
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001600def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1601 if op.type == Op.Mean and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001602 inp, axis = op.inputs
1603 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001604 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001605 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001606 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001607
1608 # Height and width axes have different index depending on dimensions
1609 if axis.shape == [] or axis.shape[0] == 1: # single axis
1610 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1611 if dims in (2, 3):
1612 if axis == 0:
1613 h, w = shape[axis], 1
1614 else:
1615 h, w = 1, shape[axis]
1616 else:
1617 if axis == 1:
1618 h, w = shape[axis], 1
1619 else:
1620 h, w = 1, shape[axis]
1621 else: # multiple axes
1622 axis = sorted(axis.values)
1623 h, w = [shape[i] for i in axis]
1624
1625 # Set necessary depthwise attributes
1626 op.attrs.update(
1627 {
1628 "padding": Padding.VALID,
1629 "stride_h": 1,
1630 "stride_w": 1,
1631 "strides": (1, 1, 1, 1),
1632 "depth_multiplier": 1,
1633 "channel_multiplier": 1,
1634 "dilation_h_factor": 1,
1635 "dilation_w_factor": 1,
1636 "dilation": (1, 1, 1, 1),
1637 }
1638 )
1639 # Change op type
1640 op.type = Op.DepthwiseConv2DBias
1641 # Set IFM/OFM shapes after changing op type
1642 op.set_ifm_ofm_shapes()
1643
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001644 weight_scale, bias = 1, 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001645 ofmq, ifmq = op.ofm.quantization, inp.quantization
Johan Alfvén9d51ec42022-10-27 16:30:01 +02001646 if ifmq.is_scaling_equal(ofmq):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001647 # Here we can just use a simple AvgPool with truncating rounding,
1648 # as we're emulating simple integer division.
1649 op.rounding_mode = NpuRoundingMode.TRUNCATE
1650 op.type = Op.AvgPool
1651 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1652 else:
1653 op.rounding_mode = NpuRoundingMode.NATURAL
1654 weight_scale = 1 / (h * w)
1655 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1656 bias = -ifmq.zero_point * h * w
1657 fiq = ifmq.clone()
1658 fiq.zero_point = 0
1659 op.forced_input_quantization = fiq
1660
1661 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001662 def extend_dims(dim, in_shape):
1663 if dim < 4:
1664 in_shape = [1] + in_shape
1665 if dim == 2:
1666 in_shape += [1]
1667 return in_shape
1668
1669 if dims < 4 or dims_ofm < 4:
1670 # Fix the ofm dimension when keep_dims is false
1671 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1672 if isinstance(axis, int) and dims_ofm + 1 == dims:
1673 ofm_shape.insert(axis, 1)
1674 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1675 for i in axis:
1676 ofm_shape.insert(i, 1)
1677 shape = extend_dims(dims, shape)
1678 dims_ofm = len(ofm_shape)
1679 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1680 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001681
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001682 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001683 weight_shape = None
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001684 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001685 # This can only happen and be done for multiple axes, and
1686 # h * w <= 256 for DepthwiseConv2DBias
1687 # h * w <= 4096 for AvgPool
1688 # which is checked in supported ops
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001689 shape = [shape[0], 1, h * w, shape[3]]
1690 op.ifm_shapes[0] = Shape4D(shape)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001691 weight_shape = [1, h * w, shape[3], shape[0]]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001692 if h > 256 and op.type == Op.AvgPool:
1693 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1694
1695 # If the AvgPool version is used, we don't need to do anything else
1696 if op.type == Op.AvgPool:
wilisa0179a89042022-11-02 17:18:43 +00001697 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001698 return op
1699
1700 # Make unit weight tensor quantization
1701 weight_quant = ifmq.clone()
1702 weight_quant.min = 0
1703 weight_quant.max = 255
1704 weight_quant.scale_f32 = weight_scale
1705 weight_quant.zero_point = 0
1706
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001707 if weight_shape is None:
1708 # Set weight shape to [H,W,C,B]
1709 weight_shape = [h, w, shape[3], shape[0]]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001710
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001711 # Add unit weight tensor
1712 op.set_input_tensor(
1713 create_const_tensor(
1714 "weights",
1715 weight_shape,
1716 inp.dtype,
1717 np.ones(weight_shape),
1718 value_dtype=np.uint8,
1719 quantization=weight_quant,
1720 ),
1721 1,
1722 )
James Peet7519d502021-07-19 16:47:58 +01001723 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001724
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001725 # Add bias tensor
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001726 bias_shape = [shape[-1]]
1727 op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
wilisa0179a89042022-11-02 17:18:43 +00001728 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001729
1730 return op
1731
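# Hedged numeric sketch of the depthwise-conv mean emulation above, using example values only:
# with unit weights, weight_scale = 1 / (h * w) and a bias of -zp_in * h * w, the accumulator
# holds sum(x - zp_in), so the output scaling yields the mean of the zero-point-adjusted input.
# E.g. _example_mean_via_depthwise([10, 12, 14], zp_in=10) returns 2.0.
def _example_mean_via_depthwise(values, zp_in=0):
    h_times_w = len(values)
    acc = sum(values) + (-zp_in * h_times_w)  # integer accumulation including the emulated bias
    return acc / h_times_w  # the 1 / (h * w) weight scale applied on the way out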
1732
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001733def optimise_quantize(op: Operation, arch, nng):
1734
1735 if op.type == Op.Quantize and op.run_on_npu:
1736
1737 ifm, ofm = op.get_ifm_ofm()
1738 input_values = ifm.values
1739
1740 # Guard clause - input not const or no values to quantize
1741 if ifm.ops[0].type != Op.Const or input_values is None:
1742 return op
1743
1744 # Singular val in numpy array, convert to indexable array
1745 if input_values.ndim == 0:
1746 input_values = np.array([input_values])
1747
Fredrik Svedberg11563172022-07-06 14:54:12 +02001748 # Requantize int8 to int8 or int16 to int16
1749 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001750
1751 # scale needs to use double precision to match TFLite reference kernel
1752 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1753 effective_multiplier, effective_shift = quantise_scale(effective_scale)
1754
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001755 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001756 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001757 input_val = val - ifm.quantization.zero_point
1758
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001759 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1760 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001761
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001762 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1763 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001764
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001765 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1766 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001767
1768 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001769 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001770
1771 quantized_vals = []
1772 for val in input_values:
1773
1774 # Derive quantized value
1775 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001776 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1777 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001778
1779 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001780 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1781
1782 # Unsupported data type
1783 else:
1784 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001785
1786 # Make quantize op const and disconnect from parent node
1787
1788 # Remove reference of the current quant op from the parent tensor's consumer list
1789 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1790
1791 # Clear any references to parent node
1792 op.inputs = []
1793
1794 # Convert this quantize op to const
1795 op.type = Op.Const
1796
1797 return op
1798
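# Minimal sketch of the constant requantisation performed by optimise_quantize, with made-up
# int8 scales and zero points; quantise_scale and fp_math.multiply_by_quantized_multiplier are
# the same helpers the pass uses, and the clamp corresponds to the int8 quant_min/quant_max.
def _example_requantize_int8(vals, ifm_scale=0.05, ifm_zp=0, ofm_scale=0.1, ofm_zp=3):
    multiplier, shift = quantise_scale(np.float64(ifm_scale) / np.float64(ofm_scale))
    out = []
    for val in vals:
        requant = fp_math.multiply_by_quantized_multiplier(val - ifm_zp, multiplier, shift) + ofm_zp
        out.append(max(-128, min(127, requant)))
    return out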
1799
Ayaan Masood4965fae2022-06-29 11:30:57 +01001800def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1801 """Static optimisation for SHAPE operator output value known at compile time"""
1802
1803 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
1804
1805 if op.type == Op.Shape and op.run_on_npu:
1806
1807 ifm, ofm = op.get_ifm_ofm()
1808
1809 if len(ifm.shape) != ofm.shape[0]:
1810 return op
1811
1812 # Remove reference of the current shape op from the parent tensor's consumer list
1813 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1814
1815 # Clear any references to parent node
1816 op.inputs = []
1817
1818 # Convert this SHAPE op to const
1819 op.type = Op.Const
1820
1821 # Add size calculation to shape output tensors
1822 ofm.values = np.array(ifm.shape)
1823
1824 return op
1825
1826
Tim Hallea4ba662022-11-11 18:19:53 +00001827def fixup_dilation_gt2(op, arch, nng):
1828 assert op.run_on_npu
1829 if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
1830 dilation_w, dilation_h = op.get_kernel_dilation()
1831
1832 # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
1833 # kernel
1834 if dilation_w > 2 or dilation_h > 2:
1835 kernel_w, kernel_h = op.get_kernel_size()
1836 kernel_ic = op.weights.shape[-2]
1837 kernel_oc = op.weights.shape[-1]
1838
1839 # if the dilation is a multiple of 2 then the hardware dilation can be enabled to provide that multiple
1840 # of 2. this allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
1841 # odd = 1, even = 2
1842 hw_dilation_h = 1 if (dilation_h & 1) else 2
1843 hw_dilation_w = 1 if (dilation_w & 1) else 2
1844
1845 scale_dilation_h = dilation_h // hw_dilation_h
1846 scale_dilation_w = dilation_w // hw_dilation_w
1847
1848 # create new empty kernel (HWIO format)
1849 new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
1850 new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1
1851
1852 new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
1853 new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)
1854
1855 # copy the original kernel values into the new sparse kernel
1856 for h in range(0, kernel_h):
1857 for w in range(0, kernel_w):
1858 new_h = h * scale_dilation_h
1859 new_w = w * scale_dilation_w
1860 new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]
1861
1862 # update the weight tensor with the new dilated kernel
1863 op.weights.shape = new_kernel_shape
1864 op.weights.values = new_kernel_values
1865
1866 # enable(=2) / disable(=1) hardware dilation
1867 op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1) # nhwc format
1868 op.attrs["dilation_h_factor"] = hw_dilation_h
1869 op.attrs["dilation_w_factor"] = hw_dilation_w
1870
1871 return op
1872
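# Worked sketch of the manual kernel dilation above, with made-up sizes: a 3x3 kernel with
# dilation 4 becomes hardware dilation 2 plus a scale dilation of 2, i.e. a 5x5 sparse kernel
# ((3 - 1) * 2 + 1 = 5) whose original taps sit on every second row/column, the rest being zeros.
# E.g. _example_dilate_kernel(np.ones((3, 3, 1, 1))).shape == (5, 5, 1, 1).
def _example_dilate_kernel(kernel_hwio, scale_dilation_h=2, scale_dilation_w=2):
    k_h, k_w, k_ic, k_oc = kernel_hwio.shape
    new_shape = ((k_h - 1) * scale_dilation_h + 1, (k_w - 1) * scale_dilation_w + 1, k_ic, k_oc)
    dilated = np.zeros(new_shape, dtype=kernel_hwio.dtype)
    dilated[::scale_dilation_h, ::scale_dilation_w, :, :] = kernel_hwio
    return dilated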
1873
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001874def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001875 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001876 return op
1877
1878
1879def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001880 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001881 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1882
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001883 for idx, sg in enumerate(nng.subgraphs):
1884 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001885 nng,
1886 sg,
1887 arch,
1888 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001889 optimisation_list,
1890 rewrite_unsupported=False,
1891 )
1892
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001893 # Pre-processing step
1894 pre_process_list = [
1895 supported_operator_check,
1896 set_ifm_ofm_op_shapes,
1897 ]
1898
Ayaan Masood4965fae2022-06-29 11:30:57 +01001899 for idx, sg in enumerate(nng.subgraphs):
1900 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1901 nng,
1902 sg,
1903 arch,
1904 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001905 pre_process_list,
1906 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001907 )
1908
1909 # Handle Concat Ops
1910 for idx, sg in enumerate(nng.subgraphs):
1911 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1912 sg.refresh_after_modification()
1913
1914 # Handle Split Ops
1915 for idx, sg in enumerate(nng.subgraphs):
1916 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1917 nng,
1918 sg,
1919 arch,
1920 [],
1921 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1922 rewrite_unsupported=False,
1923 )
1924
1925 for idx, sg in enumerate(nng.subgraphs):
1926 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001927 nng,
1928 sg,
1929 arch,
1930 [rewrite_split_ops],
1931 [],
1932 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001933 )
1934
1935 # Handle sg input output
1936 for idx, sg in enumerate(nng.subgraphs):
1937 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001938 nng,
1939 sg,
1940 arch,
1941 [],
1942 [fix_sg_input_output],
1943 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001944 )
1945
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001946 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001947 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001948 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001949 sg.refresh_after_modification()
1950
1951 # Rewrite of operators
1952 op_rewrite_list = [
1953 set_tensor_equivalence,
1954 convert_mean_to_depthwise_conv_or_avgpool,
1955 convert_depthwise_to_conv,
1956 convert_conv_to_fc,
1957 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001958 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001959 convert_mul_max_to_abs_or_lrelu,
1960 convert_lrelu,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001961 optimise_strided_conv,
1962 convert_hardswish_to_lut,
1963 rewrite_fully_connected_input,
1964 convert_batched_fc_shape,
1965 fixup_conv2d_backprop,
1966 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001967 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001968 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001969 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001970 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001971 convert_tanh_sigmoid_to_lut,
1972 replace_pad_by_hw_pad,
Tim Hallea4ba662022-11-11 18:19:53 +00001973 fixup_dilation_gt2,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001974 ]
1975
1976 for idx, sg in enumerate(nng.subgraphs):
1977 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001978 nng,
1979 sg,
1980 arch,
1981 [],
1982 op_rewrite_list,
1983 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001984 )
1985
1986 for idx, sg in enumerate(nng.subgraphs):
1987 # remove passthrough tensors and attempt further optimizations
1988 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1989 nng,
1990 sg,
1991 arch,
1992 [remove_passthrough_tensor],
1993 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1994 )
1995
1996 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
1997 # since ifm/ofm_shapes are of importance to this function
1998 for sg in nng.subgraphs:
1999 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
2000 sg.refresh_after_modification()
2001
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01002002 # Make sure that const optimisations on subgraph outputs are handled correctly
2003 for sg in nng.subgraphs:
2004 for ofm in sg.output_tensors:
2005 if ofm.is_const and ofm.ops[0].type_changed:
2006 # Subgraph output cannot be const - insert a memory copy
2007 op = ofm.ops[0]
2008 ofm_clone = ofm.clone()
2009 ofm_clone.values = ofm.values
2010 ofm.values = None
2011 np_dtype = ofm.dtype.as_numpy_type()
2012 zero = create_const_tensor("zero", [1], ofm.dtype, [0], np_dtype, quantization=ofm.quantization)
2013 memcpy = create_add_nop(f"{ofm.name}_copy")
2014 memcpy.add_input_tensor(ofm_clone)
2015 memcpy.add_input_tensor(zero)
2016 memcpy.set_output_tensor(ofm)
2017 memcpy.set_ifm_ofm_shapes()
2018 op.set_output_tensor(ofm_clone)
2019 DebugDatabase.add_optimised(op, memcpy)
2020
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002021 return nng