# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens

def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op

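# Worked example for rewrite_concat_ops above (illustrative values only, not
# taken from a real network): concatenating two [1, 4, 4, 8] inputs along
# axis 3 becomes two ConcatSliceWrite average pools writing into the
# [1, 4, 4, 16] OFM at write offsets [0, 0, 0, 0] and [0, 0, 0, 8], so each
# input lands in its own slice of the output feature map.
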
def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens

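# Worked example for rewrite_split_ops above (illustrative values only): a
# Split of a [1, 4, 4, 16] tensor into two outputs along the depth axis gives
# each output a SplitSliceRead with read shape [1, 4, 4, 8]; the first read
# starts at offset [0, 0, 0, 0] and the second at [0, 0, 0, 8], since the
# start offset accumulates the depth of every preceding output.
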
def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)

def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt

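# Worked example for calc_padding_and_skirt above (illustrative values only):
# a 3x3 kernel with stride 1 and SAME padding needs a total padding of 2 in
# each dimension, split as left/top = 2 // 2 = 1 and right/bottom = 3 // 2 = 1,
# giving padding = (1, 1, 1, 1) and skirt = (1, 1, 2 - 1, 2 - 1) = (1, 1, 1, 1).
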
def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt

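# Worked example for calc_upscaled_padding_and_skirt above (illustrative
# values only, assuming the usual total-padding rule of kernel - stride for
# evenly divisible sizes): a 3x3 kernel with SAME padding, stride 1 and x2
# upscaling gives xpad = ypad = 2 on the upscaled input, so
# right/bottom = (2 + 1) // 2 - 1 = 0 and left/top = 3 - 1 - 0 = 2,
# i.e. padding = skirt = (2, 2, 0, 0).
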
def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
        DebugDatabase.add_optimised(op, op)

    return op

# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    name = op.inputs[1].name + "_add"
    dtype = op.inputs[0].dtype
    shape = op.ofm_shapes[0].as_list()
    values = np.zeros(shape, dtype.as_numpy_type())
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = 1.0
    quantization.zero_point = 0
    op.set_input_tensor(create_const_tensor(name, shape, dtype, values, quantization=quantization), 1)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op

# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype.type == BaseType.UnsignedInt:
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm_dtype,
            np.array(weight_values).reshape(weight_shape),
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)

    return op

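# Worked example for convert_resizenn_ac_to_depthwise_conv above (illustrative
# values only): for upscale_factor = 2 the flattened kernel has four entries
# and centre_coeff = (2 // 2) * 2 + (2 // 2) = 3, so the weights become
# [0, 0, 0, 1] - position 'D' in the diagram - and the depthwise convolution
# simply selects the below-and-right nearest neighbor value.
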
# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, scaled_op)

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # Keep 1x1 kernel and average pool, this applies both when
            # half-pixel-centers is True and False. Calculations are the
            # same in the reference.
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, scaled_op)

    return op

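# Worked example for convert_resize_to_upscale_and_average_pool above
# (illustrative values only): an x8 ResizeBilinear without align_corners has
# n = log2(8) = 3, so the op becomes three chained x2 nearest neighbor
# upscales; the last one keeps EXPLICIT padding [0, 0, 7, 7] and an 8x8
# average pool kernel (the largest the hardware allows with padding).
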
def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, op)

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, dw_conv)

                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op

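# Worked example for convert_resizebilinear_to_depthwise_convolutions above
# (illustrative values only): for an x2 upscale with half pixel centers the
# interpolation fractions are 0.25 and 0.75, so each 2x2 kernel is a
# permutation of the weights {9, 3, 3, 1} (the x16 scaling is undone by the
# weight quantization scale of 1/16), and the four depthwise convolutions
# write their outputs interleaved via the doubled H/W OFM stride multipliers
# and the four tile base offsets.
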
def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op

def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op

def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape

        if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
            # If IFM is batching then also make sure OFM is batching
            h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
            op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])

    return op

def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op

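# Worked example for convert_batched_fc_shape above (illustrative values
# only): a FullyConnected with IFM shape [8, 1, 1, 32] is reshaped to
# [1, 2, 4, 32] using batching_split[8] = (2, 4), trading the batch for
# height and width that the NPU can tile; batch sizes not in the table fall
# back to (1, n).
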
def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, act_op)

def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op

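# Worked example for rewrite_stridedslice_output above (illustrative values
# only): for a [2, 3, 4] input with shrink_axis_mask = 0b010, the output
# tensor has shape [2, 4]; the loop finds axis = log2(2) = 1 and rebuilds the
# OFM shape as [2, 1, 4], recording split_axis_4D = 1 + (4 - 3) = 2 for the
# later split rewrite.
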
def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op

def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op

def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op

def fixup_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    # Do not optimize if op is not the first in the network and stride is
    # supported by the hardware
    if op.op_index != 0 and stride_x < 4:
        return op
    op.ifm.needs_linear_format = True

    if (
        (stride_x == 2 or stride_x == 4)
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, stride_x, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // stride_x, 1, (k_w + 1) // stride_x)
        padding_type = op.attrs.get("padding", None)

        # If padding is enabled, check if current padding matches optimised padding
        if not padding_type or (padding_type != Padding.VALID and curr_padding_x != optimised_padding_x):
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D(
            [ifm_shape.batch, ifm_shape.height, ifm_shape.width // stride_x, ifm_shape.depth * stride_x]
        )

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array

        # Change weight shape based on stride_x
        weight_shape[1] //= stride_x
        weight_shape[2] *= stride_x

        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op

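# Worked example for fixup_strided_conv above (illustrative values only): for
# stride_x = 2, a [1, 16, 16, 3] IFM is reinterpreted as [1, 16, 8, 6] (width
# halved, depth doubled) and a [3, 3, 3, O] HWIO weight tensor is zero-point
# padded to width 4 and reshaped to [3, 2, 6, O], after which the horizontal
# stride can be set to 1.
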
def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op

def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op

def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op

def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()
                    DebugDatabase.add_optimised(op, mul_identity)

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.Add, op.name + "_add")
        add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op

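# Worked example for the catch-all path of convert_prelu above (illustrative
# values only): with alpha = 0.25, an input value of -4 flows through
# Minimum(-4, 0) = -4, Mul(-4, 0.25) = -1, Relu(-4) = 0, Add(-1, 0) = -1,
# while +4 gives Minimum = 0, Mul = 0, Relu = 4, Add = 4 - exactly the PReLU
# response.
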
def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X    For X = -1 or X > 0
    |   \   /     This subgraph can be replaced with either
    |    Mul      an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op

def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in the TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # relu_value is now rescaled to a 16bit fixedpoint value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op

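# For reference, HardSwish is defined as hard_swish(x) = x * relu6(x + 3) / 6;
# the loop in convert_hardswish_to_lut above evaluates that definition in
# 16-bit fixed point, with the "relu-ish multiplier" term providing the
# relu6(x + 3) / 6 factor as a [0, 1] fixed point value that is multiplied
# with the rescaled input.
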
1172def convert_lrelu_to_mul_max(op, arch):
1173 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
1174 # (the opposite of convert_mul_max_to_abs_or_lrelu)
1175 ifm, ofm = op.get_ifm_ofm()
1176 if ifm is None or ofm is None:
1177 return op
1178
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001179 alpha = np.float32(op.attrs["alpha"])
1180 use_mul_max = 0 < alpha < 1
Fredrik Svedberg36424312022-09-16 09:39:26 +02001181 is_converted_prelu = "alpha_scaling" in op.attrs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001182 if use_mul_max:
1183 mul_ifm = ifm
1184 new_op = Op.Maximum
1185 else:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001186 # Need to use a different approach for alpha < 0 or alpha > 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001187 no_scale_quant = ifm.quantization.clone()
1188 no_scale_quant.scale_f32 = None
1189 no_scale_quant.zero_point = 0
1190 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
1191
1192 # Select values < 0
1193 min_op = Operation(Op.Minimum, op.name + "_min")
1194 min_op.add_input_tensor(ifm)
1195 min_op.add_input_tensor(zero)
1196 mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001197 if alpha < 0 and not is_converted_prelu:
1198 # For negative alpha that is not from a converted PReLU we need to use
1199 # int32 Mul below to perform the (negative) alpha scaling
1200 mul_ifm.dtype = DataType.int32
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001201 min_op.set_output_tensor(mul_ifm)
1202 min_op.set_ifm_ofm_shapes()
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001203 new_op = Op.Add
1204 op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001205 DebugDatabase.add_optimised(op, min_op)
1206
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001207 # Add multiplication with alpha
1208 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001209 mul_alpha.add_input_tensor(mul_ifm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001210 # Create const tensor containing alpha as scalar
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001211 quantization = ifm.quantization.clone()
1212 quantization.min = 0
1213 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
1214 quantization.zero_point = 0
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001215 alpha_dtype = mul_ifm.dtype
Fredrik Svedberg36424312022-09-16 09:39:26 +02001216 if is_converted_prelu:
1217 # The LeakyRelu was the result from convert_prelu and the scaling is provided
Fredrik Svedberg66591652022-08-29 10:51:27 +02001218 scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001219 mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001220 elif alpha == 0 or np.isinf(1 / alpha):
1221 # Handling of alpha near or at zero
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001222 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001223 scalar = 0
1224 else:
1225 quantization.scale_f32 = alpha
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001226 if alpha_dtype == DataType.int32:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001227 # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001228 scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
1229 else:
1230 scalar = 1
Tim Hall3b1578e2023-01-13 17:57:25 +00001231 alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001232 mul_alpha.add_input_tensor(alpha_tens)
1233 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1234 mul_alpha.set_output_tensor(fm_alpha)
1235 mul_alpha.set_ifm_ofm_shapes()
1236 DebugDatabase.add_optimised(op, mul_alpha)
1237
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001238 if not use_mul_max:
1239 relu_op = Operation(Op.Relu, op.name + "_relu")
1240 relu_op.add_input_tensor(ifm)
1241 fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1242 relu_op.set_output_tensor(fm_id)
1243 relu_op.set_ifm_ofm_shapes()
1244 DebugDatabase.add_optimised(op, relu_op)
1245 elif check_quantized_tens_scaling_equal(ifm, ofm):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001246 # No identity multiplication is needed
1247 fm_id = ifm
1248 else:
1249 # Add multiplication with identity
1250 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1251 mul_identity.add_input_tensor(ifm)
1252 # Create const tensor containing identity as scalar
1253 quantization = ifm.quantization.clone()
1254 quantization.min = 0
1255 quantization.max = quantization.quant_max - quantization.quant_min
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001256 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001257 quantization.zero_point = 0
Tim Hall3b1578e2023-01-13 17:57:25 +00001258 identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001259 mul_identity.add_input_tensor(identity_tens)
1260 # Make sure that fm_id is allocated to a different address than fm_alpha
1261 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1262 mul_identity.set_output_tensor(fm_id)
1263 mul_identity.set_ifm_ofm_shapes()
1264 DebugDatabase.add_optimised(op, mul_identity)
1265
1266 # Convert LeakyRelu to Max or Add, and add the results of the multiplication(s) as inputs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001267 op.type = new_op
1268 op.name = op.name.replace("LeakyRelu", new_op.name)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001269 op.inputs = []
1270 ifm.consumer_list.remove(op)
1271 op.add_input_tensor(fm_alpha)
1272 op.add_input_tensor(fm_id)
1273 op.set_ifm_ofm_shapes()
1274
1275 DebugDatabase.add_optimised(op, op)
1276 return op
1277
1278
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001279def convert_to_lut8(op, fn, fn_name):
1280 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1281 # fn is a function(real) -> real
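# A worked example, with assumed quantization ifm_scale=0.1, zp_in=128 (uint8), ofm_scale=1/256, zp_out=0:
# input x=138 gives x_real=1.0, and for fn=sigmoid y_real~=0.731, so lut_result=round_away_zero(0.731*256)=187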
1282 ifm, ofm = op.get_ifm_ofm()
1283 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1284 return op
1285 # Generate the LUT
1286 ifm_scale = np.double(ifm.quantization.scale_f32)
1287 ofm_scale = np.double(ofm.quantization.scale_f32)
1288 zp_in = ifm.quantization.zero_point
1289 zp_out = ofm.quantization.zero_point
1290 values = []
1291 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1292 quantized_min = min(ix)
1293 quantized_max = max(ix)
1294 for x in ix:
1295 x_real = ifm_scale * (x - zp_in)
1296 y_real = fn(x_real)
1297 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1298 lut_result = min(quantized_max, max(quantized_min, lut_result))
1299 values.append(lut_result)
1300 return convert_to_lut(op, values, fn_name)
1301
1302
1303def convert_lrelu_to_lut(op, arch):
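# Converts LeakyRelu to a LUT: entries below the input zero point get the alpha scaling,
# the remaining entries get the identity (ifm to ofm) scaling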
1304 ifm, ofm = op.get_ifm_ofm()
1305 # Generate the LUT
1306 alpha = op.attrs["alpha"]
1307 ifm_scale = np.double(ifm.quantization.scale_f32)
1308 ofm_scale = np.double(ofm.quantization.scale_f32)
1309 zp_in = ifm.quantization.zero_point
1310 zp_out = ofm.quantization.zero_point
1311 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1312 alpha_scalar = 1
1313 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1314 if "alpha_scaling" in op.attrs:
1315 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1316 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1317 values = []
1318 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1319 quantized_min = min(ix)
1320 quantized_max = max(ix)
1321 for x in ix:
1322 if x < zp_in:
1323 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1324 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1325 )
1326 else:
1327 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1328 lut_result = min(quantized_max, max(quantized_min, lut_result))
1329 values.append(lut_result)
1330 return convert_to_lut(op, values, "lrelu")
1331
1332
1333def convert_lrelu(op, arch, nng):
1334 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1335 if op.type != Op.LeakyRelu:
1336 return op
1337 ifm, ofm = op.get_ifm_ofm()
1338 if ifm is None or ofm is None:
1339 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001340 alpha = op.attrs["alpha"]
1341 if alpha == 0:
1342 # When alpha is 0 the operation can be converted to a ReLU
1343 op.type = Op.Relu
1344 op.name = op.name.replace("LeakyRelu", op.type.name)
1345 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001346 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1347 # use LUT for int8/uint8
1348 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001349 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001350 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001351 return op
1352 return convert_lrelu_to_mul_max(op, arch)
1353
1354
1355def convert_tanh_sigmoid_to_lut(op, arch, nng):
1356 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1357 if op.type == Op.Sigmoid:
1358 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1359 elif op.type == Op.Tanh:
1360 return convert_to_lut8(op, math.tanh, "tanh")
1361 return op
1362
1363
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001364def remove_memory_only_ops(op, arch):
1365 if op.run_on_npu and op.type in memory_only_ops:
1366 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001367
1368
1369def fuse_activation_function_with_prev(op, arch, nng):
1370 # if op is a no-op: attempts to move the activation function to the preceding op
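# e.g. Conv2D -> no-op(activation=ReLU6) becomes Conv2D(activation=ReLU6), provided the
# producer runs on the NPU, has a single consumer and no activation of its own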
1371 if not op.attrs.get("is_nop", False) or op.activation is None:
1372 return op
1373 ifm, ofm = op.get_ifm_ofm()
1374 if ifm is None or ofm is None:
1375 return op
1376 # finds the input(s) to the operation
1377 prev_op = ifm.ops[0]
1378 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1379 fuse = (
1380 prev_op.run_on_npu
1381 and prev_op.type.npu_block_type != NpuBlockType.Default
1382 and len(ifm.ops) == 1
1383 and len(prev_op.outputs[0].consumers()) == 1
1384 and prev_op.activation is None
1385 )
1386 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1387 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1388 # LUT currently only works correctly for elementwise ops
1389 fuse = False
1390 if not fuse:
1391 return op
1392 # Move the fused activation function + corresponding info to prev_op
1393 prev_op.activation = op.activation
1394 prev_op.forced_output_quantization = op.forced_output_quantization
1395 if op.activation_lut is not None:
1396 prev_op.set_activation_lut(op.activation_lut)
1397 # Bypass op
1398 prev_op.set_output_tensor(ofm)
wilisa0179a89042022-11-02 17:18:43 +00001399 DebugDatabase.add_optimised(prev_op, prev_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001400 return op
1401
1402
1403def _leading_pad_ok(leading_pad, stride, kernel_size):
1404 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1405 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
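# For example: kernel_size=7 and stride=2 give max_size=3; a leading pad of 3 (== max_size)
# or of 2 (a multiple of the stride) is ok, but a leading pad of 1 is not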
1406 max_size = kernel_size // 2
1407 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1408
1409
1410def replace_pad_by_hw_pad(op: Operation, arch, nng):
1411 """
1412 Tries to completely remove a PAD operator by using hardware padding.
1413 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1414 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1415 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1416 if both operations can be run on the NPU.
1417 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1418 """
1419 if (
1420 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001421 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001422 and op.run_on_npu
1423 and op.attrs["padding"] == Padding.VALID
1424 ):
1425 pad_op = op.ifm.ops[0]
1426 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1427 return op
1428 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1429 return op
1430 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1431 k = op.kernel
1432 k_w, k_h = k.dilated_wh()
1433
1434 # Check if the PAD operator can be replaced by hardware padding
1435 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1436 # Too much padding, it would require hardware padding to actually insert zeros
1437 return op
1438 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1439 return op
1440
1441 if op.type.is_avgpool_op():
1442 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
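# e.g. a 3x3 average pool can use hardware padding of 0 or 1 on each edge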
1443 for pad, k_size in (
1444 (left, k_w),
1445 (right, k_w),
1446 (top, k_h),
1447 (bottom, k_h),
1448 ):
1449 if pad not in (0, k_size // 2):
1450 return op
1451 # Average pool is converted to depthwise, because NPU average pool + same padding
1452 # has a special implementation that is different from PAD followed by average pool with
1453 # valid padding.
1454 k_w, k_h = op.kernel.width, op.kernel.height
1455 ifm = op.ifm
1456 # Remember other inputs
1457 other_inputs = op.inputs[1:]
1458 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1459 quantization = QuantizationParameters(0.0, 255.0)
1460 quantization.scale_f32 = 1.0 / (k_w * k_h)
1461 quantization.zero_point = 0
1462 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1463 weights = np.full(shape, 1)
1464
1465 weight_tens = create_const_tensor(
1466 op.name + "_weights",
1467 shape,
1468 op.ifm.dtype,
1469 weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001470 purpose=TensorPurpose.Weights,
1471 quantization=quantization,
1472 )
James Peet7519d502021-07-19 16:47:58 +01001473 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001474 op.type = Op.DepthwiseConv2DBias
1475 op.inputs = []
1476 op.add_input_tensor(ifm)
1477 op.add_input_tensor(weight_tens)
1478 # Add bias tensor, all biases set to 0
1479 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001480 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001481 # Add other inputs
1482 op.inputs.extend(other_inputs)
1483 op.rounding_mode = NpuRoundingMode.NATURAL
1484
1485 # Bypass the PAD operator
1486 op.set_input_tensor(pad_op.ifm, 0)
1487 # Adjust the padding attributes of the convolution operator
1488 op.attrs["padding"] = Padding.EXPLICIT
1489 op.attrs["explicit_padding"] = (top, left, bottom, right)
1490 op.set_ifm_ofm_shapes()
wilisa0179a89042022-11-02 17:18:43 +00001491 DebugDatabase.add_optimised(op, op)
1492
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001493 return op
1494
1495
1496def convert_pad(op: Operation, arch, nng):
1497 """
1498 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1499 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1500 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1501 """
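# For example: padding (top, left, bottom, right) = (1, 2, 0, 0) creates the main IFM copy at
# offset (1, 2) in the OFM, plus a 1-row top fill and a 2-column left fill of zero-point values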
1502 if op.type != Op.Pad or not op.run_on_npu:
1503 return op
1504 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1505
1506 ifm = op.ifm
1507 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001508 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001509 ofm = op.ofm
1510 assert ofm is not None
1511 ofm.ops = []
1512 ofm_shape = op.ofm_shapes[0]
1513
1514 # Average pool op that copies IFM to the right place inside the OFM
1515 shp0 = Shape4D(0, 0, 0, 0)
1516 shp_top = shp0.with_height(top)
1517 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1518 avgpool_op.activation = op.activation
1519 quant = ofm.quantization
1520 pad_value = quant.zero_point
1521 # Add operations that fill the borders of the OFM
1522 if top > 0:
1523 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1524 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001525 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001526 )
1527 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1528 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1529 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1530 if bottom > 0:
1531 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1532 zero_tens = create_const_tensor(
1533 op.name + "_bottom",
1534 shape.as_list(),
1535 ofm.dtype,
1536 shape.elements() * [pad_value],
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001537 quantization=quant,
1538 )
1539 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1540 create_avg_pool_for_concat(
1541 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1542 )
1543 if left > 0:
1544 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1545 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001546 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001547 )
1548 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1549 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1550 if right > 0:
1551 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1552 zero_tens = create_const_tensor(
Tim Hall3b1578e2023-01-13 17:57:25 +00001553 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001554 )
1555 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1556 create_avg_pool_for_concat(
1557 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1558 )
1559
1560 op.type = Op.ConcatTFLite
1561 return avgpool_op
1562
1563
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001564def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001565 if op.type.needs_bias() and op.bias is None:
1566 # Op has no bias, add bias tensor filled with zeros
1567 nr_biases = op.inputs[1].shape[-1]
1568 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001569 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1570 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1571 # For int16 the selected bias DataType will have an impact on the scaling
1572 # used when encoding the scales and biases later. The default mode will match the
1573 # reference with reduced scaling for int64 bias.
1574 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1575 # is used to emulate average pool, int32 bias should be selected for full precision
1576 # int16 scaling.
1577 if dtype is None:
1578 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1579 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001580 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1581
1582 return op
1583
1584
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001585def fixup_asymmetric_weights(op, arch, nng):
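# Per the TFLite int8 quantization spec, weights are expected to be symmetric (zero point 0);
# non-conforming weight zero points are forced to zero here, which may slightly change the numerics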
1586 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1587 if op.ifm.dtype == DataType.int8:
1588 if not np.all(op.weights.quantization.zero_point == 0):
1589 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1590 op.weights.quantization.zero_point *= 0
1591
1592 return op
1593
1594
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001595def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
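# Converts MEAN to an AvgPool when ifm/ofm scaling is equal, otherwise to a DepthwiseConv2DBias
# with unit weights scaled by 1/(h*w), where h and w are the sizes of the reduced axes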
1596 if op.type == Op.Mean and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001597 inp, axis = op.inputs
1598 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001599 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001600 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001601 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001602
1603 # Height and width axes have different indices depending on the number of dimensions
1604 if axis.shape == [] or axis.shape[0] == 1: # single axis
1605 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1606 if dims in (2, 3):
1607 if axis == 0:
1608 h, w = shape[axis], 1
1609 else:
1610 h, w = 1, shape[axis]
1611 else:
1612 if axis == 1:
1613 h, w = shape[axis], 1
1614 else:
1615 h, w = 1, shape[axis]
1616 else: # multiple axes
1617 axis = sorted(axis.values)
1618 h, w = [shape[i] for i in axis]
1619
1620 # Set necessary depthwise attributes
1621 op.attrs.update(
1622 {
1623 "padding": Padding.VALID,
1624 "stride_h": 1,
1625 "stride_w": 1,
1626 "strides": (1, 1, 1, 1),
1627 "depth_multiplier": 1,
1628 "channel_multiplier": 1,
1629 "dilation_h_factor": 1,
1630 "dilation_w_factor": 1,
1631 "dilation": (1, 1, 1, 1),
1632 }
1633 )
1634 # Change op type
1635 op.type = Op.DepthwiseConv2DBias
1636 # Set IFM/OFM shapes after changing op type
1637 op.set_ifm_ofm_shapes()
1638
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001639 weight_scale, bias = 1, 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001640 ofmq, ifmq = op.ofm.quantization, inp.quantization
Johan Alfvén9d51ec42022-10-27 16:30:01 +02001641 if ifmq.is_scaling_equal(ofmq):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001642 # Here we can just use a simple AvgPool with truncating rounding,
1643 # as we're emulating simple integer division.
1644 op.rounding_mode = NpuRoundingMode.TRUNCATE
1645 op.type = Op.AvgPool
1646 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1647 else:
1648 op.rounding_mode = NpuRoundingMode.NATURAL
1649 weight_scale = 1 / (h * w)
1650 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1651 bias = -ifmq.zero_point * h * w
1652 fiq = ifmq.clone()
1653 fiq.zero_point = 0
1654 op.forced_input_quantization = fiq
1655
1656 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001657 def extend_dims(dim, in_shape):
1658 if dim < 4:
1659 in_shape = [1] + in_shape
1660 if dim == 2:
1661 in_shape += [1]
1662 return in_shape
1663
1664 if dims < 4 or dims_ofm < 4:
1665 # Fix the ofm dimension when keep_dims is false
1666 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1667 if isinstance(axis, int) and dims_ofm + 1 == dims:
1668 ofm_shape.insert(axis, 1)
1669 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1670 for i in axis:
1671 ofm_shape.insert(i, 1)
1672 shape = extend_dims(dims, shape)
1673 dims_ofm = len(ofm_shape)
1674 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1675 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001676
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001677 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001678 weight_shape = None
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001679 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001680 # This can only happen for reductions over multiple axes, and requires
1681 # h * w <= 256 for DepthwiseConv2DBias
1682 # h * w <= 4096 for AvgPool
1683 # which is checked in supported ops
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001684 shape = [shape[0], 1, h * w, shape[3]]
1685 op.ifm_shapes[0] = Shape4D(shape)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001686 weight_shape = [1, h * w, shape[3], shape[0]]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001687 if h > 256 and op.type == Op.AvgPool:
1688 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1689
1690 # If the AvgPool version is used, we don't need to do anything else
1691 if op.type == Op.AvgPool:
wilisa0179a89042022-11-02 17:18:43 +00001692 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001693 return op
1694
1695 # Make unit weight tensor quantization
1696 weight_quant = ifmq.clone()
1697 weight_quant.min = 0
1698 weight_quant.max = 255
1699 weight_quant.scale_f32 = weight_scale
1700 weight_quant.zero_point = 0
1701
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001702 if weight_shape is None:
1703 # Set weight shape to [H,W,C,B]
1704 weight_shape = [h, w, shape[3], shape[0]]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001705
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001706 # Add unit weight tensor
1707 op.set_input_tensor(
1708 create_const_tensor(
1709 "weights",
1710 weight_shape,
1711 inp.dtype,
1712 np.ones(weight_shape),
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001713 quantization=weight_quant,
1714 ),
1715 1,
1716 )
James Peet7519d502021-07-19 16:47:58 +01001717 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001718
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001719 # Add bias tensor
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001720 bias_shape = [shape[-1]]
1721 op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
wilisa0179a89042022-11-02 17:18:43 +00001722 DebugDatabase.add_optimised(op, op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001723
1724 return op
1725
1726
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001727def optimise_quantize(op: Operation, arch, nng):
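# Constant-folds Quantize ops with a constant input: the (re)quantized values are computed
# at compile time and the op is converted to a Const that produces them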
1728
1729 if op.type == Op.Quantize and op.run_on_npu:
1730
1731 ifm, ofm = op.get_ifm_ofm()
1732 input_values = ifm.values
1733
1734 # Guard clause - input not const or no values to quantize
1735 if ifm.ops[0].type != Op.Const or input_values is None:
1736 return op
1737
1738 # Scalar value in numpy array, convert to indexable array
1739 if input_values.ndim == 0:
1740 input_values = np.array([input_values])
1741
Fredrik Svedberg11563172022-07-06 14:54:12 +02001742 # Requantize int8 to int8 or int16 to int16
1743 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001744
1745 # scale needs to use double precision to match TFLite reference kernel
1746 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1747 effective_multiplier, effective_shift = quantise_scale(effective_scale)
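# For example, with assumed scales: ifm scale 0.5 and ofm scale 0.25 give an effective_scale
# of 2.0, so a stored value of 10 (both zero points 0) requantizes to 20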
1748
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001749 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001750 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001751 input_val = val - ifm.quantization.zero_point
1752
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001753 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1754 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001755
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001756 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1757 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001758
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001759 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1760 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001761
1762 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001763 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001764
1765 quantized_vals = []
1766 for val in input_values:
1767
1768 # Derive quantized value
1769 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001770 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1771 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001772
1773 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001774 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1775
1776 # Unsupported data type
1777 else:
1778 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001779
1780 # Make quantize op const and disconnect from parent node
1781
1782 # Remove reference of the current quant op from the parent tensor's consumer list
1783 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1784
1785 # Clear any references to parent node
1786 op.inputs = []
1787
1788 # Convert this quantize op to const
1789 op.type = Op.Const
1790
1791 return op
1792
1793
Ayaan Masood4965fae2022-06-29 11:30:57 +01001794def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1795 """Static optimisation for a SHAPE operator whose output value is known at compile time"""
1796
1797 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
1798
1799 if op.type == Op.Shape and op.run_on_npu:
1800
1801 ifm, ofm = op.get_ifm_ofm()
1802
1803 if len(ifm.shape) != ofm.shape[0]:
1804 return op
1805
1806 # Remove reference of the current shape op from the parent tensor's consumer list
1807 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1808
1809 # Clear any references to parent node
1810 op.inputs = []
1811
1812 # Convert this SHAPE op to const
1813 op.type = Op.Const
1814
1815 # Add size calculation to shape output tensors
1816 ofm.values = np.array(ifm.shape)
1817
1818 return op
1819
1820
Tim Hallea4ba662022-11-11 18:19:53 +00001821def fixup_dilation_gt2(op, arch, nng):
1822 assert op.run_on_npu
1823 if op.type == Op.Conv2DBias or op.type == Op.DepthwiseConv2DBias:
1824 dilation_w, dilation_h = op.get_kernel_dilation()
1825
1826 # if dilation in either axis is greater than that supported by the hardware then we must manually dilate the
1827 # kernel
1828 if dilation_w > 2 or dilation_h > 2:
1829 kernel_w, kernel_h = op.get_kernel_size()
1830 kernel_ic = op.weights.shape[-2]
1831 kernel_oc = op.weights.shape[-1]
1832
1833 # If the dilation is a multiple of 2 then hardware dilation can be enabled to provide that multiple
1834 # of 2. This allows the kernel size to be reduced (via the scaled dilation) by half in that dimension.
1835 # odd = 1, even = 2
1836 hw_dilation_h = 1 if (dilation_h & 1) else 2
1837 hw_dilation_w = 1 if (dilation_w & 1) else 2
1838
1839 scale_dilation_h = dilation_h // hw_dilation_h
1840 scale_dilation_w = dilation_w // hw_dilation_w
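# For example: a 3x3 kernel with dilation 4 uses hw dilation 2 and scale dilation 2, giving
# a 5x5 sparse kernel with the same effective receptive field (9x9)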
1841
1842 # create new empty kernel (HWIO format)
1843 new_kernel_h = (kernel_h - 1) * scale_dilation_h + 1
1844 new_kernel_w = (kernel_w - 1) * scale_dilation_w + 1
1845
1846 new_kernel_shape = [new_kernel_h, new_kernel_w, kernel_ic, kernel_oc]
1847 new_kernel_values = np.zeros(new_kernel_shape, dtype=op.weights.values.dtype)
1848
1849 # copy the original kernel values into the new sparse kernel
1850 for h in range(0, kernel_h):
1851 for w in range(0, kernel_w):
1852 new_h = h * scale_dilation_h
1853 new_w = w * scale_dilation_w
1854 new_kernel_values[new_h, new_w, :, :] = op.weights.values[h, w, :, :]
1855
1856 # update the weight tensor with the new dilated kernel
1857 op.weights.shape = new_kernel_shape
1858 op.weights.values = new_kernel_values
1859
1860 # enable(=2) / disable(=1) hardware dilation
1861 op.attrs["dilation"] = (1, hw_dilation_h, hw_dilation_w, 1) # nhwc format
1862 op.attrs["dilation_h_factor"] = hw_dilation_h
1863 op.attrs["dilation_w_factor"] = hw_dilation_w
1864
1865 return op
1866
1867
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001868def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001869 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001870 return op
1871
1872
1873def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001874 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001875 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1876
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001877 for idx, sg in enumerate(nng.subgraphs):
1878 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001879 nng,
1880 sg,
1881 arch,
1882 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001883 optimisation_list,
1884 rewrite_unsupported=False,
1885 )
1886
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001887 # Pre-processing step
1888 pre_process_list = [
1889 supported_operator_check,
1890 set_ifm_ofm_op_shapes,
1891 ]
1892
Ayaan Masood4965fae2022-06-29 11:30:57 +01001893 for idx, sg in enumerate(nng.subgraphs):
1894 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1895 nng,
1896 sg,
1897 arch,
1898 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001899 pre_process_list,
1900 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001901 )
1902
1903 # Handle Concat Ops
1904 for idx, sg in enumerate(nng.subgraphs):
1905 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1906 sg.refresh_after_modification()
1907
1908 # Handle Split Ops
1909 for idx, sg in enumerate(nng.subgraphs):
1910 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1911 nng,
1912 sg,
1913 arch,
1914 [],
1915 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1916 rewrite_unsupported=False,
1917 )
1918
1919 for idx, sg in enumerate(nng.subgraphs):
1920 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001921 nng,
1922 sg,
1923 arch,
1924 [rewrite_split_ops],
1925 [],
1926 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001927 )
1928
1929 # Handle sg input output
1930 for idx, sg in enumerate(nng.subgraphs):
1931 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001932 nng,
1933 sg,
1934 arch,
1935 [],
1936 [fix_sg_input_output],
1937 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001938 )
1939
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001940 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001941 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001942 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001943 sg.refresh_after_modification()
1944
1945 # Rewrite of operators
1946 op_rewrite_list = [
1947 set_tensor_equivalence,
1948 convert_mean_to_depthwise_conv_or_avgpool,
1949 convert_depthwise_to_conv,
1950 convert_conv_to_fc,
1951 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001952 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001953 convert_mul_max_to_abs_or_lrelu,
1954 convert_lrelu,
Raul Farkas090f18a2023-01-24 16:29:06 +00001955 fixup_strided_conv,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001956 convert_hardswish_to_lut,
1957 rewrite_fully_connected_input,
1958 convert_batched_fc_shape,
1959 fixup_conv2d_backprop,
1960 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001961 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001962 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001963 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001964 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001965 convert_tanh_sigmoid_to_lut,
1966 replace_pad_by_hw_pad,
Tim Hallea4ba662022-11-11 18:19:53 +00001967 fixup_dilation_gt2,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001968 ]
1969
1970 for idx, sg in enumerate(nng.subgraphs):
1971 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001972 nng,
1973 sg,
1974 arch,
1975 [],
1976 op_rewrite_list,
1977 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001978 )
1979
1980 for idx, sg in enumerate(nng.subgraphs):
1981 # remove passthrough tensors and attempt further optimizations
1982 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1983 nng,
1984 sg,
1985 arch,
1986 [remove_passthrough_tensor],
1987 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1988 )
1989
1990 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
1991 # since ifm/ofm_shapes are of importance to this function
1992 for sg in nng.subgraphs:
1993 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1994 sg.refresh_after_modification()
1995
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01001996 # Make sure that const optimisations on subgraph outputs are handled correctly
1997 for sg in nng.subgraphs:
1998 for ofm in sg.output_tensors:
1999 if ofm.is_const and ofm.ops[0].type_changed:
2000 # Subgraph output cannot be const - insert a memory copy
2001 op = ofm.ops[0]
2002 ofm_clone = ofm.clone()
2003 ofm_clone.values = ofm.values
2004 ofm.values = None
Tim Hall3b1578e2023-01-13 17:57:25 +00002005 zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
Fredrik Svedbergf3c7d552022-11-04 09:48:49 +01002006 memcpy = create_add_nop(f"{ofm.name}_copy")
2007 memcpy.add_input_tensor(ofm_clone)
2008 memcpy.add_input_tensor(zero)
2009 memcpy.set_output_tensor(ofm)
2010 memcpy.set_ifm_ofm_shapes()
2011 op.set_output_tensor(ofm_clone)
2012 DebugDatabase.add_optimised(op, memcpy)
2013
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002014 return nng