blob: 3646b01eff30b7c5a2a0ec1f57a1a0a999b8ea2f [file] [log] [blame]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
18# to do the traversal of the graph.
19import math
20import uuid
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020021
22import numpy as np
23
24from . import fp_math
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020025from . import rewrite_graph
26from . import scaling
27from .api import NpuRoundingMode
Fredrik Svedberga04f2f72022-07-06 13:42:24 +020028from .data_type import BaseType
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020029from .data_type import DataType
30from .debug_database import DebugDatabase
31from .errors import UnsupportedFeatureError
32from .ethos_u55_regs.ethos_u55_regs import resampling_mode
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020033from .graph_optimiser_util import bypass_memory_only_ops
Patrik Gustavssonc74682c2021-08-17 14:26:38 +020034from .graph_optimiser_util import calc_explicit_padding
Patrik Gustavssondf995102021-08-23 15:33:59 +020035from .graph_optimiser_util import convert_depthwise_to_conv
Patrik Gustavssonf436ada2021-09-14 14:56:48 +020036from .graph_optimiser_util import convert_to_lut
Patrik Gustavssondf995102021-08-23 15:33:59 +020037from .graph_optimiser_util import fix_sg_input_output
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020038from .graph_optimiser_util import memory_only_ops
Patrik Gustavssonf1580f02021-09-01 12:43:02 +020039from .graph_optimiser_util import move_splitsliceread_to_consumer
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020040from .graph_optimiser_util import needed_total_padding
41from .graph_optimiser_util import set_ifm_ofm_op_shapes
42from .graph_optimiser_util import set_tensor_equivalence
43from .numeric_util import clamp_sigmoid
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020044from .numeric_util import round_away_zero
45from .operation import create_activation_function
Fredrik Svedberg1a7527c2021-09-13 15:52:16 +020046from .operation import ExplicitScaling
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020047from .operation import NpuBlockType
48from .operation import Op
49from .operation import Operation
50from .operation import Padding
51from .operation_util import create_avgpool_nop
52from .operation_util import get_pad_values_from_input
Ayaan Masood25f48dd2022-06-29 18:16:04 +010053from .scaling import quantise_scale
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020054from .shape4d import Shape4D
55from .softmax import SoftMax
56from .tensor import check_quantized_tens_scaling_equal
57from .tensor import create_const_tensor
58from .tensor import create_equivalence_id
59from .tensor import QuantizationParameters
60from .tensor import Tensor
61from .tensor import TensorPurpose
62from .tflite_mapping import optype_to_builtintype
63
64passthrough_nodes = (Op.Identity,)
65
66
67def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
68 """Creates an average pool for the given concat op/input feature map"""
69 ofm = concat_op.ofm
70 avgpool_op = create_avgpool_nop(name)
71 avgpool_op.inputs = [ifm]
72 avgpool_op.outputs = [ofm]
73
74 avgpool_op.write_offset = write_offset
75 avgpool_op.write_shape = ifm_shape
76 ofm.ops.append(avgpool_op)
77 DebugDatabase.add_optimised(concat_op, avgpool_op)
78 avgpool_op.ifm_shapes.append(ifm_shape)
79 avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
80 avgpool_op.memory_function = Op.ConcatSliceWrite
81 return avgpool_op
82
83
84def remove_passthrough_tensor(tens, arch, nng):
85 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
86 assert len(tens.ops[0].inputs) == 1
87 tens = tens.ops[0].inputs[0]
88 return tens
89
90
91def rewrite_concat_ops(op, arch):
92 if not op.run_on_npu or not op.type.is_concat_op():
93 return
94
95 axis_4D = 0
96 ofm = op.ofm
97 ofm.ops = []
98 offset = 0
99
100 unfuse_activation_function(op)
101
102 if op.type == Op.Pack:
103 # Pack is also referred to as Stack
104 axis = int(op.attrs["axis"])
105 if axis < 0: # Convert to positive axis
106 axis = len(op.inputs[0].shape) + 1 + axis
107
108 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
109
110 axis_4D = axis + (4 - len(desired_shape))
111
112 for idx, inp in enumerate(op.inputs):
113 op.ifm_shapes[idx] = Shape4D(desired_shape)
114 op.type = Op.PackReshaped
115
116 inputs, axis = op.get_concat_inputs_axis()
117 for idx, inp in enumerate(inputs):
118 if op.type != Op.PackReshaped:
119 op.ifm_shapes[idx] = Shape4D(inp.shape)
120 if axis >= 0:
121 axis_4D = axis + (4 - len(inp.shape))
122 else:
123 axis_4D = axis
124 write_offset = [0, 0, 0, 0]
125 write_offset[axis_4D] = offset
126 concat_end = offset + op.ifm_shapes[idx][axis_4D]
127 create_avg_pool_for_concat(
128 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
129 )
130 offset = concat_end
131 assert ofm.shape[axis] == offset
132
133 return op
134
135
136def rewrite_split_ops(tens, arch, nng):
137
138 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
139 split_op = tens.ops[0]
140
141 # Not supported so leave it and run on CPU
142 if not split_op.run_on_npu:
143 return tens
144
145 inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
146
147 tens.ops = []
148 new_op = Operation(Op.SplitSliceRead, split_op.name)
149 new_op.inputs = [inp]
150 ofm_shape_idx = 0
Tim Hall51a8dce2021-12-20 16:49:27 +0000151 if None in (offset_end, offset_start):
152 read_shape = None
153 else:
154 # the read shape is relative to each start offset
155 read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200156
157 # For Split the offset cannot be extracted from the tensor so it has to
158 # be calculated from the index of the output tensor
159 if axis is not None:
160 # Get the start and end of the split
161 offset_start = [0] * 4
162 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
163 for idx, out in enumerate(outputs):
164 if axis_4D_list is not None:
165 axis_4D = axis_4D_list[idx]
166 else:
167 split_op.ofm_shapes[idx] = Shape4D(out.shape)
168 if axis >= 0:
169 axis_4D = axis + (4 - len(out.shape))
170 else:
171 axis_4D = axis
172
173 if out == tens:
174 ofm_shape_idx = idx
175 read_shape = split_op.ofm_shapes[idx]
176 break
177
178 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
179
180 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
181 new_op.read_shapes[0] = read_shape
182 new_op.run_on_npu = True
183 new_op.set_output_tensor(tens)
184 new_op.ifm_shapes.append(Shape4D(inp.shape))
185 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
186 DebugDatabase.add_optimised(split_op, new_op)
187
188 return tens
189
190
191def remove_SplitSliceRead(op, arch):
192
193 if op.type == Op.SplitSliceRead:
194 # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool need to be inserted
195 if (
196 len(op.ofm.consumer_list) == 1
197 and op.ofm.consumer_list[0] is not None
198 and op.ofm.consumer_list[0].run_on_npu
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +0200199 and op.ofm.consumer_list[0].type not in memory_only_ops
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200200 and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
201 ):
202 # SplitSliceRead can be performed by tensor consumer
203 cons_op = op.ofm.consumer_list[0]
Patrik Gustavssonf1580f02021-09-01 12:43:02 +0200204 move_splitsliceread_to_consumer(op, cons_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200205 else:
206 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
207 avgpool_op.add_input_tensor(op.ifm)
208 avgpool_op.outputs = [op.ofm]
209 op.ofm.ops.remove(op)
210 op.ofm.ops.append(avgpool_op)
211 avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
212 avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
213 avgpool_op.read_offsets[0] = op.read_offsets[0]
214 avgpool_op.read_shapes[0] = op.read_shapes[0]
215
216 op.ifm.consumer_list.remove(op)
217 DebugDatabase.add_optimised(op, avgpool_op)
218
219
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200220def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
221 k_w, k_h = kernel.dilated_wh()
222 s_x, s_y = kernel.stride
223 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
224 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
225 if padding_type == Padding.SAME:
226 left_pad = (xpad + 0) // 2
227 right_pad = (xpad + 1) // 2
228 top_pad = (ypad + 0) // 2
229 bottom_pad = (ypad + 1) // 2
230 elif padding_type == Padding.VALID:
231 left_pad = 0
232 right_pad = 0
233 top_pad = 0
234 bottom_pad = 0
235 elif padding_type == Padding.EXPLICIT:
236 # Padding is specified in a PAD operator which has been bypassed.
237 top, left, bottom, right = explicit_padding
238 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
239 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
240 else:
Tim Hall0ab2edc2022-02-23 17:58:02 +0000241 raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200242 padding = (top_pad, left_pad, bottom_pad, right_pad)
243 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
244 return padding, skirt
245
246
247def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
248 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
249 if padding_type == Padding.SAME:
250 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
251 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
252 right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
253 bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
254 left_pad = max(kernel_width - 1 - right_pad, 0)
255 top_pad = max(kernel_height - 1 - bottom_pad, 0)
256 elif padding_type == Padding.VALID:
257 right_pad = max(kernel_width - 2, 0)
258 bottom_pad = max(kernel_height - 2, 0)
259 left_pad = kernel_width - 1
260 top_pad = kernel_height - 1
261 else:
Tim Hall0ab2edc2022-02-23 17:58:02 +0000262 raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200263 padding = (top_pad, left_pad, bottom_pad, right_pad)
264 skirt = padding
265 return padding, skirt
266
267
268def fixup_conv2d_backprop(op, arch, nng):
269 if op.type == Op.Conv2DBackpropInput:
270 # flip the inputs
271 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
272 op.type = Op.Conv2DBackpropInputSwitchedBias
Tim Hall3c5cfe92022-03-16 16:31:57 +0000273 op.ifm_resampling_mode = resampling_mode.TRANSPOSE
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200274
275 # Update strides
276 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
277
278 return op
279
280
281# Convert the op to an elementwise add
Tim Hall885033b2022-07-21 11:46:03 +0100282def convert_resize_1x1_to_add(op):
283 op.type = Op.Add # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200284 op.name = op.name + "_add"
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200285 # Create an input tensor filled with zeros
286 shape = op.ofm_shapes[0].as_list()
287 tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
James Peet7519d502021-07-19 16:47:58 +0100288 tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200289 tens.quantization = QuantizationParameters(0.0, 255.0)
290 tens.quantization.scale_f32 = 1.0
291 tens.quantization.zero_point = 0
292 tens.consumer_list = [op]
293 tens_op = op.inputs[1].ops[0]
294 tens_op.set_output_tensor(tens)
295 # Set the add inputs
296 op.inputs[1] = op.inputs[0]
297 op.inputs[0] = tens
298 op.set_ifm_ofm_shapes()
299
300 return op
301
302
Tim Hall885033b2022-07-21 11:46:03 +0100303# Convert ResizeNearestNeightbor with align corners to a depthwise convolution. The IFM will already have been upscaled
304# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
305# to select the appropriate nearest neighbor value
306def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
307 ifm = op.ifm
308 ofm = op.ofm
309 output_depth = ofm.shape[-1]
310 dw_op_attrs = {
311 "padding": Padding.VALID,
312 "stride_h": 1,
313 "stride_w": 1,
314 "strides": (1, 1, 1, 1),
315 "depth_multiplier": 1,
316 "channel_multiplier": 1,
317 "dilation_h_factor": 1,
318 "dilation_w_factor": 1,
319 "dilation": (1, 1, 1, 1),
320 }
321
322 # change resizebilinear to depthwise
323 op.type = Op.DepthwiseConv2DBias
324 op.attrs.update(dw_op_attrs)
325 op.set_input_tensor(ifm, 0) # ifm tensor index
326 op.activation = None
327
328 # add input resample to resize by x2
329 op.ifm_resampling_mode = resampling_mode.NEAREST
330
331 # don't care about the rounding mode as it is nearest neighbor
332
333 # setup weight tensor
334 weight_quant = QuantizationParameters()
335 weight_quant.scale_f32 = 1.0 # no scaling as only a single non-zero coeff to select the desired value
336 weight_quant.zero_point = 0
337 weight_quant.quant_dim = 0
338 ofm_dtype = ofm.dtype
339 if ofm_dtype == DataType.uint8:
340 weight_value_dtype = np.uint8
341 weight_quant.quant_min = 0
342 weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
343 else:
344 if ofm_dtype == DataType.int8:
345 weight_value_dtype = np.int8
346 else:
347 assert ofm_dtype == DataType.int16
348 weight_value_dtype = np.int16
349
350 weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
351 weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
352
353 weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth] # HWIO
354
355 # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
356 # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
357 # below-and-right (i.e. next) to it (D).
358 # 0---1---2
359 # | A | B |
360 # 1---*---+
361 # | C | D |
362 # 2---+---+
363 weight_values = [0] * (upscale_factor * upscale_factor)
364 centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
365 weight_values[centre_coeff] = 1
366
367 # add weight tensor, this will discard the size tensor of the resize op
368 op.set_input_tensor(
369 create_const_tensor(
370 "weights",
371 weight_shape,
372 ofm.dtype,
373 np.array(weight_values).reshape(weight_shape),
374 value_dtype=weight_value_dtype,
375 quantization=weight_quant,
376 ),
377 1, # inputs tensor weight index
378 )
379
380 # setup bias tensor by assign None and then call the fix-up function to create a suitable tensor.
381 # need to append the bias tensor as resize ops only have 2 inputs
382 assert len(op.inputs) == 2
383 op.inputs.append(None)
384 fixup_bias_tensors(op, None, None)
385
386 # finally update the shape incase we've change the tensor shapes or connections
387 op.set_ifm_ofm_shapes()
388
389 return op
390
391
392# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
393# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
394# upscale factor is limited to x8 because of the limit 8x8 kernel size limit for average pool with padding.
395def convert_resize_to_upscale_and_average_pool(op):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200396 pre_op = op
397 outputs = op.outputs
Rickard Boline546def2022-01-25 15:45:00 +0000398 dtype = op.ifm.dtype
Tim Hall885033b2022-07-21 11:46:03 +0100399
Rickard Boline546def2022-01-25 15:45:00 +0000400 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
Tim Hall47c76362022-07-18 21:26:47 +0100401 op.attrs["padding"] = Padding.SAME # doesn't really matter as the kernel is 1x1
Tim Hall3c5cfe92022-03-16 16:31:57 +0000402 op.ifm_resampling_mode = resampling_mode.NEAREST
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200403
404 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
Tim Hall47c76362022-07-18 21:26:47 +0100405
406 # Get upscale factor that was calculated in the supported operators check
407 upscale_factor = op.attrs["upscale_factor"]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200408
Rickard Boline546def2022-01-25 15:45:00 +0000409 # Calculate how many times 2x2 upscaling needs to be performed
Tim Hallf9267da2022-04-20 20:19:48 +0100410 # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
411 # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
Rickard Boline546def2022-01-25 15:45:00 +0000412 n = int(np.log2(upscale_factor))
413
Tim Hall885033b2022-07-21 11:46:03 +0100414 # Perform x2 upscaling n-1 times
Rickard Boline546def2022-01-25 15:45:00 +0000415 scaled_op = pre_op
416 for count in range(n - 1):
417 if count > 0:
418 scaled_op = op.clone(f"_{count}")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200419 scaled_op.inputs[0] = pre_op.outputs[0]
420
Tim Hall885033b2022-07-21 11:46:03 +0100421 # Nearest neighbor x2 upscaling
Tim Hall47c76362022-07-18 21:26:47 +0100422 upscaled_shape = upscaled_shape * 2
Rickard Boline546def2022-01-25 15:45:00 +0000423 shape = op.ofm_shapes[0].as_list()
424 shape[1:3] = upscaled_shape
425 out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
426 out_tens.quantization = op.outputs[0].quantization.clone()
427 scaled_op.set_output_tensor(out_tens)
428 pre_op = scaled_op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200429
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200430 scaled_op.set_ifm_ofm_shapes()
431
Tim Hall885033b2022-07-21 11:46:03 +0100432 # Last x2 upscaling
Rickard Boline546def2022-01-25 15:45:00 +0000433 if n > 1:
434 scaled_op = op.clone(f"_{n-1}")
435 scaled_op.inputs[0] = pre_op.outputs[0]
Tim Hall885033b2022-07-21 11:46:03 +0100436
437 if scaled_op.original_type == Op.ResizeBilinear:
438 if scaled_op.attrs["align_corners"]:
439 # no padding
440 scaled_op.attrs["padding"] = Padding.VALID
441 else:
442 # padding to the right and bottom (limits average pool to 8x8 kernel)
443 scaled_op.attrs["padding"] = Padding.EXPLICIT
444 scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
445
446 # kernal size dependent on the upscaling factor
447 scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
448 else: # Op.ResizeNearestNeighbor
449 if scaled_op.attrs["align_corners"]:
450 # use depthwise conv to select the correct value
451 scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
452 else:
453 # keep 1x1 kernel and average pool
454 pass
455
Rickard Boline546def2022-01-25 15:45:00 +0000456 scaled_op.outputs = outputs
457 scaled_op.outputs[0].ops = [scaled_op]
458 scaled_op.set_ifm_ofm_shapes()
459
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200460 return op
461
462
Tim Hall885033b2022-07-21 11:46:03 +0100463def fixup_resize(op, arch, nng):
464 if op.type.is_resize_op() and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200465 if op.ifm_shapes[0] == op.ofm_shapes[0]:
Tim Hall885033b2022-07-21 11:46:03 +0100466 # Bypass the resize op which is essentially a NOP
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200467 op.inputs = op.inputs[:1]
468 op.type = Op.Identity
469 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
Tim Hall885033b2022-07-21 11:46:03 +0100470 convert_resize_1x1_to_add(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200471 else:
Tim Hall885033b2022-07-21 11:46:03 +0100472 convert_resize_to_upscale_and_average_pool(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200473
474 return op
475
476
477def convert_nop_split_to_identity(op, arch, nng):
478 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
479 # the list comprehension should return a list with a single tensor
480 # if it shouldn't, remove_passthrough_tensor will fail appropriately
481 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
482 op.type = Op.Identity
483 return op
484
485
Ayaan Masooda2ec5aa2022-04-21 14:28:03 +0100486def rewrite_fully_connected_input(op: Operation, arch, nng):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200487
Ayaan Masooda2ec5aa2022-04-21 14:28:03 +0100488 if op.type == Op.FullyConnected:
489 new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
490 assert new_shape is not None, "Tensor can not be reshaped to 2D"
491 op.ifm_shapes[0] = new_shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200492 return op
493
494
495def convert_batched_fc_shape(op, arch, nng):
496 if op.type == Op.FullyConnected:
497 # Check if the first dimension indicates batching
498 if op.ifm_shapes[0].batch > 1:
499 batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
500 n = op.ifm_shapes[0].batch
501 h, w = batching_split.get(n, (1, n))
502 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
503
504 # Reshape Weights to be 4D. IO becomes HWIO
505 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +0100506 weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
507 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200508
509 n = op.ofm_shapes[0].batch
510 h, w = batching_split.get(n, (1, n))
511 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
512 return op
513
514
515def unfuse_activation_function(op):
516 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
517 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
518 op.activation = None
519 out_tens = op.outputs[0]
520 intermediate_tens = out_tens.clone("_act_intermediate")
521 act_op.set_output_tensor(out_tens)
522 act_op.add_input_tensor(intermediate_tens)
523 op.set_output_tensor(intermediate_tens)
524 act_op.set_ifm_ofm_shapes()
525
526
527def rewrite_stridedslice_output(op, arch, nng):
528 if not op.run_on_npu or op.type != Op.StridedSlice:
529 return op
530
531 new_axis_mask = op.attrs["new_axis_mask"]
532 shrink_axis_mask = op.attrs["shrink_axis_mask"]
533
534 if shrink_axis_mask == 0 and new_axis_mask == 0:
535 return op
536
537 axis_4D = [0] * len(op.outputs)
538 for idx, out_tens in enumerate(op.outputs):
539 output_shape = list(out_tens.shape)
540
541 if shrink_axis_mask != 0:
542 n = 0
543 axis = 0
544 while shrink_axis_mask:
545 prev_mask = shrink_axis_mask
546 n += 1
547 shrink_axis_mask &= shrink_axis_mask - 1
548 axis = int(math.log2(prev_mask - shrink_axis_mask))
549 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
550
551 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
552 op.attrs["shrink_axis_mask"] = 0
553 if axis >= 0:
554 axis_4D[idx] = axis + (4 - len(output_shape))
555 else:
556 axis_4D[idx] = axis
557 op.ofm_shapes[idx] = Shape4D(output_shape)
558
559 elif new_axis_mask != 0:
560 n = 0
561 axis = 0
562 while new_axis_mask:
563 prev_mask = new_axis_mask
564 n += 1
565 new_axis_mask &= new_axis_mask - 1
566 axis = int(math.log2(prev_mask - new_axis_mask))
567 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
568 new_axis_mask >>= 1
569
570 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
571 op.attrs["new_axis_mask"] = 0
572 if axis >= 0:
573 axis_4D[idx] = axis + (4 - len(output_shape))
574 else:
575 axis_4D[idx] = axis
576 op.ofm_shapes[idx] = Shape4D(output_shape)
577
578 op.attrs["split_axis_4D"] = axis_4D
579 return op
580
581
582def rewrite_unpack_output(op, arch, nng):
583 tens = op.outputs[0]
584 if op.run_on_npu and op.type == Op.Unpack:
585 # Unpack is also referred to as Unstack
586 axis = int(op.attrs["axis"])
587 if axis < 0: # Convert to positive axis
588 axis = len(op.inputs[0].shape) + 1 + axis
589 op.type = Op.UnpackReshaped
590 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
591
592 axis_4D = axis + (4 - len(desired_output_shape))
593 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
594
595 for idx, out_tens in enumerate(op.outputs):
596 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
597 return op
598
599
600def add_padding_fields(op, arch, nng):
601 if op.run_on_npu:
602 if "padding" in op.attrs:
603 input_shape = op.ifm_shapes[0]
604 output_shape = op.ofm_shapes[0]
605 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
606 kernel_size = op.inputs[1].shape[:2]
607 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
608 kernel_size = op.attrs["ksize"][1:3]
609 else:
610 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
611
612 if op.type == Op.Conv2DBackpropInputSwitchedBias:
613 upscaling_factor = output_shape.height // input_shape.height
614 padding, skirt = calc_upscaled_padding_and_skirt(
615 op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
616 )
617 else:
618 padding, skirt = calc_padding_and_skirt(
Jonas Ohlssond8575072022-03-30 10:30:25 +0200619 op.attrs["padding"],
620 op.kernel,
621 input_shape,
622 op.attrs.get("explicit_padding"),
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200623 )
624
625 op.attrs["explicit_padding"] = padding
626 op.attrs["skirt"] = skirt
627
628 return op
629
630
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200631def reorder_depthwise_weights(op, arch, nng):
632 if op.type.is_depthwise_conv2d_op():
633 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +0100634 weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
635 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200636 weight_tensor.weight_transpose_depthwise = True
637
638 return op
639
640
641def optimise_strided_conv(op, arch, nng):
Louis Verhaard43d27582022-03-17 14:06:00 +0100642 if op.type != Op.Conv2DBias or op.op_index != 0:
643 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200644 stride_x, stride_y = op.get_kernel_stride()
Louis Verhaard43d27582022-03-17 14:06:00 +0100645 weight_tensor = op.weights
646 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200647
648 if (
Louis Verhaard43d27582022-03-17 14:06:00 +0100649 stride_x == 2
650 and ifm_shape.depth <= 4
651 and ifm_shape.width % 2 == 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200652 and weight_tensor is not None
653 and weight_tensor.shape[1] >= 2
654 ):
Louis Verhaard43d27582022-03-17 14:06:00 +0100655 k_w, _ = op.get_kernel_size()
656 curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
657 optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
658 if curr_padding_x != optimised_padding_x:
659 # Horizontal padding would become different after optimisation; this would not work
660 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200661 # IFM
662 op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
663
664 # Weights
665 weight_shape = weight_tensor.shape
666 if weight_shape[1] % 2 != 0:
667 weight_shape[1] = weight_shape[1] + 1
668 padded_array = np.zeros(weight_shape)
669 for i in range(weight_shape[0]):
670 padded_array[i] = np.vstack(
671 [
James Peet7519d502021-07-19 16:47:58 +0100672 weight_tensor.values[i],
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200673 np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
674 ]
675 )
James Peet7519d502021-07-19 16:47:58 +0100676 weight_tensor.values = padded_array
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200677 weight_shape[1] //= 2
678 weight_shape[2] *= 2
James Peet7519d502021-07-19 16:47:58 +0100679 weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200680 weight_tensor.set_all_shapes(weight_shape)
681 # If multiple copies of the weights are used, we could avoid
682 # them having the same address by changing the value_id
683 weight_tensor.value_id = uuid.uuid4()
684
685 # Strides
686 stride_x = 1
687 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
688
689 return op
690
691
692def convert_conv_to_fc(op, arch, nng):
693 # Conv 1x1 can be equivalent to Fully Connected.
694 # By representing certain convs as fully connected layers, Vela can better determine wether or not to use
695 # caching/double buffering for the weights.
696 # (Weights dont need to be reloaded for convs when IFM H and W are 1)
697 if op.type == Op.Conv2DBias:
698 h = op.ifm_shapes[0].height
699 w = op.ifm_shapes[0].width
700 kh, kw, _, _ = op.inputs[1].shape
701 if h == 1 and w == 1 and kh == 1 and kw == 1:
702 # Overwrite this op as a Fully Connected Op
703 op.name += "_fc"
704 op.type = Op.FullyConnected
705 op.attrs = {
706 "weights_format": 0,
707 }
708 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
709 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +0100710 weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
711 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200712
713 DebugDatabase.add_optimised(op, op)
714 return op
715
716
717def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
718 if op.run_on_npu and op.type.is_relu_op():
719 ifm = op.inputs[0]
720 ofm = op.outputs[0]
721 # Relu with differing IFM and OFM scaling cannot be fused with another primary op
722 # and requires its own to be inserted
723 if not check_quantized_tens_scaling_equal(ifm, ofm):
724 # Override this op with its own primary op (avgpool)
725 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
726 # And fuse the original activation function to it
727 relu_fused_op.activation = create_activation_function(op.type)
Fredrik Svedberg1a7527c2021-09-13 15:52:16 +0200728 # Add explicit rescaling
729 rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
730 multiplier, shift = scaling.quantise_scale(rescale)
731 relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200732 # Tidy up and assign the ifm and ofm to the new op
733 ifm.consumer_list.remove(op)
734
735 relu_fused_op.add_input_tensor(ifm)
736 relu_fused_op.set_output_tensor(ofm)
737 relu_fused_op.set_ifm_ofm_shapes()
738 op = relu_fused_op
739 return op
740
741
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200742def convert_softmax(op, arch, nng):
743 if op.type == Op.Softmax and op.run_on_npu:
744 softmax = SoftMax(op)
745 op = softmax.get_graph()
746 return op
747
748
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +0200749def convert_prelu(op, arch, nng):
750 if op.type == Op.Prelu:
751 ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
752 if None in (ifm, alpha, ofm):
753 return op
754
755 no_scale_quant = ifm.quantization.clone()
756 no_scale_quant.scale_f32 = None
757 no_scale_quant.zero_point = 0
758 zero = create_const_tensor("zero_const", [1, 1, 1, 1], ifm.dtype, [0], quantization=no_scale_quant)
759
760 # Select values < 0
761 min_op = Operation(Op.Minimum, op.name + "_min")
762 min_op.add_input_tensor(ifm)
763 min_op.add_input_tensor(zero)
764 fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
765 min_op.set_output_tensor(fm_negative)
766 min_op.set_ifm_ofm_shapes()
767 DebugDatabase.add_optimised(op, min_op)
768
769 # and multiply with alpha tensor
770 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
771 mul_alpha.add_input_tensor(fm_negative)
772 mul_alpha.add_input_tensor(alpha)
773 fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
774 mul_alpha.set_output_tensor(fm_alpha)
775 mul_alpha.set_ifm_ofm_shapes()
776 DebugDatabase.add_optimised(op, mul_alpha)
777
778 # Select (and scale) values > 0
779 relu_op = Operation(Op.Relu, op.name + "_relu")
780 relu_op.add_input_tensor(ifm)
781 fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
782 relu_op.set_output_tensor(fm_scaled)
783 relu_op.set_ifm_ofm_shapes()
784 DebugDatabase.add_optimised(op, relu_op)
785
786 # Add scaled and alpha multiplied values (without scaling)
787 add_op = Operation(Op.RescaleAdd, op.name + "_add")
788 add_op.rescale = (1, 0) # No scale or shift
789 add_op.add_input_tensor(fm_alpha)
790 add_op.add_input_tensor(fm_scaled)
791 add_op.set_output_tensor(ofm)
792 add_op.set_ifm_ofm_shapes()
793
794 DebugDatabase.add_optimised(op, add_op)
795 ifm.consumer_list.remove(op)
796 op = add_op
797
798 return op
799
800
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200801def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
802 r"""Whenever there is a subgraph with this topology:
803
Jonas Ohlssond8575072022-03-30 10:30:25 +0200804 Input X For X = -1 or X > 0
805 | \ / This subgraph can be replaced with either
806 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
807 | /
808 Max
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200809 """
810
811 if op.type == Op.Maximum:
812 # finds the Mul input(s) to the Max
813 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
814 if len(muls) == 1:
815 mul = muls[0].ops[0]
816 elif len(muls) == 2:
817 # In the case both inputs are Muls, find the one with the same input as the Max
818 mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
819 else:
820 # No Mul inputs
821 return op
822
823 # make sure the Mul doesn't have any other consumers
824 mul_ofm = mul.outputs[0]
825 if len(mul_ofm.consumers()) != 1:
826 return op
827 # make sure the Mul doesn't have a fused activation function
828 if mul.activation:
829 return op
830 ifm, ofm = op.get_ifm_ofm()
831 if ifm is None or ofm is None:
832 return op
833
834 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
835 return op
836 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
837 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
838 return op
839
840 # finds the branched input that goes to both the Max and the Mul
841 shared = set(op.inputs) & set(mul.inputs)
842 if len(shared) == 1:
843 shared_in = shared.pop()
844 # find the constant scalar input to the Mul
845 const_tens = (set(mul.inputs) - {shared_in}).pop()
846 # check that it is a scalar
847 if const_tens.shape != []:
848 return op
849 const = const_tens.ops[0]
850 # check that it is a constant
851 if const.type != Op.Const:
852 return op
853 # Remove the Mul from the shared input's consumers
854 shared_in.consumer_list.remove(mul)
855 else:
856 return op
857
858 val = const.outputs[0].values
859 if val >= 0:
860 new_op = Op.LeakyRelu
861 op.attrs["alpha"] = val
862 # to produce bit exact results, the alpha is not enough;
863 # save additional scaling info in attr "alpha_scale", to be used as input
864 # to the LUT construction
James Peet7519d502021-07-19 16:47:58 +0100865 alpha_scalar = const_tens.values - const_tens.quantization.zero_point
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200866 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
867 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
868 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
869 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
870 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
871 elif val == -1:
872 new_op = Op.Abs
873 else:
874 return op
875
876 op.type = new_op
877 op.name = op.name.replace("Maximum", new_op.name)
878 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
879 op.inputs = [shared_in]
880 op.set_ifm_ofm_shapes()
881
882 # Record optimisation in debug database
883 DebugDatabase.add_optimised(op, op)
884
885 return op
886
887
888def convert_hardswish_to_lut(op, arch, nng):
889 if op.type == Op.HardSwish:
890 ifm, ofm = op.get_ifm_ofm()
891 # Generate the LUT
892 ifm_scale = np.double(ifm.quantization.scale_f32)
893 ofm_scale = np.double(ofm.quantization.scale_f32)
894 zp_in = ifm.quantization.zero_point
895 zp_out = ofm.quantization.zero_point
896 ifm_scale_hires = (1 / 128) * ifm_scale
897 relu_multiplier = np.double(3 / 32768)
898 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
899 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
900 # Use 16bit scale
901 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
902 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
903
904 values = []
905 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
906 quantized_min = min(ix)
907 quantized_max = max(ix)
908 for x in ix:
909 input_value = x - zp_in
910 input_value_hires = input_value * 128
911 # Compute the input value on essentially the output scale, not shifted yet
912 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
913 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
914 relu_value = np.int16(input_value_hires)
915 if relu_shift < 31:
916 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
917
918 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
919
920 if relu_shift < 31:
921 relu_value = fp_math.shift_left16(relu_value, 1)
922
923 if relu_shift > 31:
924 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
925
926 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
927 # Now convert that to a 16bit fixedpoint value in [0, 1]
928 relu_value = (relu_value + (1 << 15)) >> 1
929 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
930 shift = 31 - out_shift
931 shift = -shift if shift < 0 else 0
932 # Finally apply the output shift
933 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
934 lut_result = min(quantized_max, max(quantized_min, lut_result))
935 values.append(lut_result)
936 return convert_to_lut(op, values, "hardswish")
937 return op
938
939
940def convert_lrelu_to_mul_max(op, arch):
941 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
942 # (the opposite of convert_mul_max_to_abs_or_lrelu)
943 ifm, ofm = op.get_ifm_ofm()
944 if ifm is None or ofm is None:
945 return op
946
947 # Add multiplication with alpha
948 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
949 mul_alpha.add_input_tensor(ifm)
950 # Create const tensor containing alpha as scalar
Fredrik Svedbergcce872b2021-09-02 15:20:52 +0200951 alpha = np.float32(op.attrs["alpha"])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200952 quantization = ifm.quantization.clone()
953 quantization.min = 0
954 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
955 quantization.zero_point = 0
Fredrik Svedbergcce872b2021-09-02 15:20:52 +0200956 if np.isinf(1 / alpha):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200957 # Handling of alpha near zero
Fredrik Svedbergcce872b2021-09-02 15:20:52 +0200958 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200959 scalar = 0
960 else:
961 quantization.scale_f32 = alpha
962 scalar = alpha
963 alpha_tens = create_const_tensor(
964 op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
965 )
James Peet7519d502021-07-19 16:47:58 +0100966 alpha_tens.values = np.array([1])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200967 mul_alpha.add_input_tensor(alpha_tens)
968 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
969 mul_alpha.set_output_tensor(fm_alpha)
970 mul_alpha.set_ifm_ofm_shapes()
971 DebugDatabase.add_optimised(op, mul_alpha)
972
973 if check_quantized_tens_scaling_equal(ifm, ofm):
974 # No identity multiplication is needed
975 fm_id = ifm
976 else:
977 # Add multiplication with identity
978 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
979 mul_identity.add_input_tensor(ifm)
980 # Create const tensor containing identity as scalar
981 quantization = ifm.quantization.clone()
982 quantization.min = 0
983 quantization.max = quantization.quant_max - quantization.quant_min
Fredrik Svedbergcce872b2021-09-02 15:20:52 +0200984 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200985 quantization.zero_point = 0
986 identity_tens = create_const_tensor(
987 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
988 )
989 mul_identity.add_input_tensor(identity_tens)
990 # Make sure that fm_id is allocated to a different address than fm_alpha
991 fm_id = ofm.clone(op.name + "_id", set_unique=True)
992 mul_identity.set_output_tensor(fm_id)
993 mul_identity.set_ifm_ofm_shapes()
994 DebugDatabase.add_optimised(op, mul_identity)
995
996 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
997 op.type = Op.Maximum
998 op.name = op.name.replace("LeakyRelu", "Maximum")
999 op.inputs = []
1000 ifm.consumer_list.remove(op)
1001 op.add_input_tensor(fm_alpha)
1002 op.add_input_tensor(fm_id)
1003 op.set_ifm_ofm_shapes()
1004
1005 DebugDatabase.add_optimised(op, op)
1006 return op
1007
1008
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001009def convert_to_lut8(op, fn, fn_name):
1010 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1011 # fn is a function(real) -> real
1012 ifm, ofm = op.get_ifm_ofm()
1013 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1014 return op
1015 # Generate the LUT
1016 ifm_scale = np.double(ifm.quantization.scale_f32)
1017 ofm_scale = np.double(ofm.quantization.scale_f32)
1018 zp_in = ifm.quantization.zero_point
1019 zp_out = ofm.quantization.zero_point
1020 values = []
1021 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1022 quantized_min = min(ix)
1023 quantized_max = max(ix)
1024 for x in ix:
1025 x_real = ifm_scale * (x - zp_in)
1026 y_real = fn(x_real)
1027 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1028 lut_result = min(quantized_max, max(quantized_min, lut_result))
1029 values.append(lut_result)
1030 return convert_to_lut(op, values, fn_name)
1031
1032
1033def convert_lrelu_to_lut(op, arch):
1034 ifm, ofm = op.get_ifm_ofm()
1035 # Generate the LUT
1036 alpha = op.attrs["alpha"]
1037 ifm_scale = np.double(ifm.quantization.scale_f32)
1038 ofm_scale = np.double(ofm.quantization.scale_f32)
1039 zp_in = ifm.quantization.zero_point
1040 zp_out = ofm.quantization.zero_point
1041 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1042 alpha_scalar = 1
1043 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1044 if "alpha_scaling" in op.attrs:
1045 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1046 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1047 values = []
1048 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1049 quantized_min = min(ix)
1050 quantized_max = max(ix)
1051 for x in ix:
1052 if x < zp_in:
1053 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1054 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1055 )
1056 else:
1057 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1058 lut_result = min(quantized_max, max(quantized_min, lut_result))
1059 values.append(lut_result)
1060 return convert_to_lut(op, values, "lrelu")
1061
1062
1063def convert_lrelu(op, arch, nng):
1064 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1065 if op.type != Op.LeakyRelu:
1066 return op
1067 ifm, ofm = op.get_ifm_ofm()
1068 if ifm is None or ofm is None:
1069 return op
1070 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1071 # use LUT for int8/uint8
1072 return convert_lrelu_to_lut(op, arch)
1073 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
1074 # use LeakyRelu unmodified for int16 with equal input/output scaling
1075 return op
1076 return convert_lrelu_to_mul_max(op, arch)
1077
1078
1079def convert_tanh_sigmoid_to_lut(op, arch, nng):
1080 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1081 if op.type == Op.Sigmoid:
1082 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1083 elif op.type == Op.Tanh:
1084 return convert_to_lut8(op, math.tanh, "tanh")
1085 return op
1086
1087
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001088def remove_memory_only_ops(op, arch):
1089 if op.run_on_npu and op.type in memory_only_ops:
1090 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001091
1092
1093def fuse_activation_function_with_prev(op, arch, nng):
1094 # if op is a no-op: attempts to move the activation function to the preceding op
1095 if not op.attrs.get("is_nop", False) or op.activation is None:
1096 return op
1097 ifm, ofm = op.get_ifm_ofm()
1098 if ifm is None or ofm is None:
1099 return op
1100 # finds the input(s) to the operation
1101 prev_op = ifm.ops[0]
1102 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1103 fuse = (
1104 prev_op.run_on_npu
1105 and prev_op.type.npu_block_type != NpuBlockType.Default
1106 and len(ifm.ops) == 1
1107 and len(prev_op.outputs[0].consumers()) == 1
1108 and prev_op.activation is None
1109 )
1110 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1111 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1112 # LUT currently only works correctly for elementwise ops
1113 fuse = False
1114 if not fuse:
1115 return op
1116 # Move the fused activation function + corresponding info to prev_op
1117 prev_op.activation = op.activation
1118 prev_op.forced_output_quantization = op.forced_output_quantization
1119 if op.activation_lut is not None:
1120 prev_op.set_activation_lut(op.activation_lut)
1121 # Bypass op
1122 prev_op.set_output_tensor(ofm)
1123 DebugDatabase.add_optimised(op, prev_op)
1124 return op
1125
1126
1127def _leading_pad_ok(leading_pad, stride, kernel_size):
1128 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1129 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
1130 max_size = kernel_size // 2
1131 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1132
1133
1134def replace_pad_by_hw_pad(op: Operation, arch, nng):
1135 """
1136 Tries to completely remove a PAD operator by using hardware padding.
1137 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1138 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1139 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1140 if both operations can be run on the NPU.
1141 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1142 """
1143 if (
1144 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001145 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001146 and op.run_on_npu
1147 and op.attrs["padding"] == Padding.VALID
1148 ):
1149 pad_op = op.ifm.ops[0]
1150 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1151 return op
1152 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1153 return op
1154 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1155 k = op.kernel
1156 k_w, k_h = k.dilated_wh()
1157
1158 # Check if the PAD operator can be replaced by hardware padding
1159 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1160 # Too much padding, it would require hardware padding to actually insert zeros
1161 return op
1162 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1163 return op
1164
1165 if op.type.is_avgpool_op():
1166 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1167 for pad, k_size in (
1168 (left, k_w),
1169 (right, k_w),
1170 (top, k_h),
1171 (bottom, k_h),
1172 ):
1173 if pad not in (0, k_size // 2):
1174 return op
1175 # Average pool is converted to depthwise, because NPU average pool + same padding
1176 # has a special implementation that is different from PAD followed by average pool with
1177 # valid padding.
1178 k_w, k_h = op.kernel.width, op.kernel.height
1179 ifm = op.ifm
1180 # Remember other inputs
1181 other_inputs = op.inputs[1:]
1182 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1183 quantization = QuantizationParameters(0.0, 255.0)
1184 quantization.scale_f32 = 1.0 / (k_w * k_h)
1185 quantization.zero_point = 0
1186 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1187 weights = np.full(shape, 1)
1188
1189 weight_tens = create_const_tensor(
1190 op.name + "_weights",
1191 shape,
1192 op.ifm.dtype,
1193 weights,
1194 np.uint8,
1195 purpose=TensorPurpose.Weights,
1196 quantization=quantization,
1197 )
James Peet7519d502021-07-19 16:47:58 +01001198 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001199 op.type = Op.DepthwiseConv2DBias
1200 op.inputs = []
1201 op.add_input_tensor(ifm)
1202 op.add_input_tensor(weight_tens)
1203 # Add bias tensor, all biases set to 0
1204 op.inputs.append(None)
1205 fixup_bias_tensors(op, arch, nng)
1206 # Add other inputs
1207 op.inputs.extend(other_inputs)
1208 op.rounding_mode = NpuRoundingMode.NATURAL
1209
1210 # Bypass the PAD operator
1211 op.set_input_tensor(pad_op.ifm, 0)
1212 # Adjust the padding attributes of the convolution operator
1213 op.attrs["padding"] = Padding.EXPLICIT
1214 op.attrs["explicit_padding"] = (top, left, bottom, right)
1215 op.set_ifm_ofm_shapes()
1216 return op
1217
1218
1219def convert_pad(op: Operation, arch, nng):
1220 """
1221 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1222 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1223 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1224 """
1225 if op.type != Op.Pad or not op.run_on_npu:
1226 return op
1227 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1228
1229 ifm = op.ifm
1230 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001231 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001232 ofm = op.ofm
1233 assert ofm is not None
1234 ofm.ops = []
1235 ofm_shape = op.ofm_shapes[0]
1236
1237 # Average pool op that copies IFM to the right place inside the OFM
1238 shp0 = Shape4D(0, 0, 0, 0)
1239 shp_top = shp0.with_height(top)
1240 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1241 avgpool_op.activation = op.activation
1242 quant = ofm.quantization
1243 pad_value = quant.zero_point
1244 # Add operations that fill the borders of the OFM
1245 if top > 0:
1246 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1247 zero_tens = create_const_tensor(
1248 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1249 )
1250 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1251 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1252 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1253 if bottom > 0:
1254 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1255 zero_tens = create_const_tensor(
1256 op.name + "_bottom",
1257 shape.as_list(),
1258 ofm.dtype,
1259 shape.elements() * [pad_value],
1260 np.uint8,
1261 quantization=quant,
1262 )
1263 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1264 create_avg_pool_for_concat(
1265 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1266 )
1267 if left > 0:
1268 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1269 zero_tens = create_const_tensor(
1270 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1271 )
1272 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1273 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1274 if right > 0:
1275 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1276 zero_tens = create_const_tensor(
1277 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1278 )
1279 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1280 create_avg_pool_for_concat(
1281 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1282 )
1283
1284 op.type = Op.ConcatTFLite
1285 return avgpool_op
1286
1287
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001288def fixup_bias_tensors(op, arch, nng):
1289 if op.type.needs_bias() and op.bias is None:
1290 # Op has no bias, add bias tensor filled with zeros
1291 nr_biases = op.inputs[1].shape[-1]
1292 bias_values = [0] * nr_biases
1293 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001294 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1295
1296 return op
1297
1298
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001299def fixup_asymmetric_weights(op, arch, nng):
1300 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1301 if op.ifm.dtype == DataType.int8:
1302 if not np.all(op.weights.quantization.zero_point == 0):
1303 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1304 op.weights.quantization.zero_point *= 0
1305
1306 return op
1307
1308
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001309def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1310 if op.type == Op.Mean and op.run_on_npu:
1311 keep_dims = op.attrs.get("keep_dims", False)
1312 inp, axis = op.inputs
1313 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001314 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001315 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001316 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001317
1318 # Height and width axes have different index depending on dimensions
1319 if axis.shape == [] or axis.shape[0] == 1: # single axis
1320 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1321 if dims in (2, 3):
1322 if axis == 0:
1323 h, w = shape[axis], 1
1324 else:
1325 h, w = 1, shape[axis]
1326 else:
1327 if axis == 1:
1328 h, w = shape[axis], 1
1329 else:
1330 h, w = 1, shape[axis]
1331 else: # multiple axes
1332 axis = sorted(axis.values)
1333 h, w = [shape[i] for i in axis]
1334
1335 # Set necessary depthwise attributes
1336 op.attrs.update(
1337 {
1338 "padding": Padding.VALID,
1339 "stride_h": 1,
1340 "stride_w": 1,
1341 "strides": (1, 1, 1, 1),
1342 "depth_multiplier": 1,
1343 "channel_multiplier": 1,
1344 "dilation_h_factor": 1,
1345 "dilation_w_factor": 1,
1346 "dilation": (1, 1, 1, 1),
1347 }
1348 )
1349 # Change op type
1350 op.type = Op.DepthwiseConv2DBias
1351 # Set IFM/OFM shapes after changing op type
1352 op.set_ifm_ofm_shapes()
1353
1354 weight_scale, bias = 1, None
1355 ofmq, ifmq = op.ofm.quantization, inp.quantization
1356 # Set rounding mode, scaling and zero point based on which reference implementation to match
1357 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1358 if inp.dtype == DataType.uint8:
1359 # This attribute means a different scaling calculation is used in order to match the reference
1360 op.low_precision_scaling = True
1361 weight_scale = h * w
1362 # Set zero points to 0 as they will be adjusted for with bias term
1363 foq = ofmq.clone()
1364 foq.zero_point = 0
1365 fiq = ifmq.clone()
1366 fiq.zero_point = 0
1367 op.forced_input_quantization = fiq
1368 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1369 # If the bias term is outside uint8 range, we need an Add op to apply it.
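# Illustrative (made-up) numbers: ifm zero point -128, ifm scale 0.5, ofm scale 0.25 and
# ofm zero point 0 give bias_term = 0 - int(-128 * 0.5 / 0.25) = 256, which does not fit
# in uint8 and therefore takes the Add-op path below.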
1370 if bias_term < 0 or bias_term > 255:
1371 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1372 # The bias term has a wider type (int32) than the input/output (uint8).
1373 # 16 bits is enough, since the bias is added to/subtracted from a uint8 value;
1374 # the bias can only effectively assume values in the range [-255, 255].
1375 intermediate.dtype = DataType.int16
1376 intermediate.quantization.zero_point = 0
1377 add_op = Operation(Op.Add, op.name + "_bias")
1378 add_op.forced_output_quantization = foq
1379 add_op.add_input_tensor(intermediate)
1380 quant = QuantizationParameters()
1381 quant.zero_point = 0
1382 bias_term_tens = create_const_tensor(
1383 op.name + "_bias",
1384 [1, 1, 1, 1],
1385 DataType.int16,
1386 [bias_term],
1387 np.int16,
1388 quantization=quant,
1389 )
1390 add_op.add_input_tensor(bias_term_tens)
1391 add_op.set_output_tensor(op.ofm)
1392 add_op.set_ifm_ofm_shapes()
1393 add_op.activation = op.activation
1394 op.activation = None
1395 op.set_output_tensor(intermediate)
1396 op.set_ifm_ofm_shapes()
1397 # If not, we can just do it with the OFM zero point.
1398 else:
1399 foq.zero_point = bias_term
1400 op.forced_output_quantization = foq
1401 else:
1402 assert inp.dtype == DataType.int8
1403 # Use a depthwise convolution to calculate the sum,
1404 # followed by a multiplication with 1/N to get the MEAN
1405 weight_scale = 1
1406 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1407 intermediate.dtype = DataType.int16
1408 mul_op = Operation(Op.Mul, op.name + "_mul")
1409 mul_op.add_input_tensor(intermediate)
1410 # Create scalar containing 1/N
1411 quant = QuantizationParameters()
1412 quant.zero_point = 0
1413 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1414 # while rounding mode NATURAL would round this to -1.
1415 # This can only occur if N is even, and can be emulated by
1416 # multiplying with a number that is slightly smaller than 1/N.
1417 # It must be so small that other roundings are not affected;
1418 # the calculated value is based on worst case,
1419 # which is a sum of 256 * N (the maximum sum that can occur with int8)
1420 n = int(h * w)
1421 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1422 quant.scale_f32 = 1 / (n - eps)
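# Example with made-up numbers: for a 2x2 mean, n = 4 and eps = 1 / 1280, so the scalar
# scale becomes 1 / (4 - 1/1280) ~= 0.25005. A sum of -6 then scales to about -1.5003,
# which NATURAL rounding takes to -2, matching the reference that rounds -1.5 down to -2.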
1423 scalar = create_const_tensor(
1424 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1425 )
1426 mul_op.add_input_tensor(scalar)
1427 mul_op.set_output_tensor(op.ofm)
1428 mul_op.set_ifm_ofm_shapes()
1429 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1430 mul_op.activation = op.activation
1431 op.activation = None
1432 op.set_output_tensor(intermediate)
1433 op.set_ifm_ofm_shapes()
1434 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1435 # Here we can just use a simple AvgPool with truncating rounding,
1436 # as we're emulating simple integer division.
1437 op.rounding_mode = NpuRoundingMode.TRUNCATE
1438 op.type = Op.AvgPool
1439 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1440 else:
1441 op.rounding_mode = NpuRoundingMode.NATURAL
1442 weight_scale = 1 / (h * w)
1443 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
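# e.g. with ifm zero point 5 and a 3x3 mean the bias becomes -45, cancelling the 9
# zero-point contributions that the unit-weight sum accumulates.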
1444 bias = -ifmq.zero_point * h * w
1445 fiq = ifmq.clone()
1446 fiq.zero_point = 0
1447 op.forced_input_quantization = fiq
1448
1449 # Extend the shapes to 4 dimensions
1450 def extend_dims(dim, in_shape):
1451 if dim < 4:
1452 in_shape = [1] + in_shape
1453 if dim == 2:
1454 in_shape += [1]
1455 return in_shape
1456
1457 if dims < 4 or dims_ofm < 4:
1458 # Fix the ofm dimension when keep_dims is false
1459 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1460 if isinstance(axis, int) and dims_ofm + 1 == dims:
1461 ofm_shape.insert(axis, 1)
1462 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1463 for i in axis:
1464 ofm_shape.insert(i, 1)
1465 shape = extend_dims(dims, shape)
1466 dims_ofm = len(ofm_shape)
1467 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1468 op.set_ifm_ofm_shapes()
1469
1470 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
1471 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
1472 shape = [shape[0], 1, h * w, shape[3]]
1473 op.ifm_shapes[0] = Shape4D(shape)
1474 if h > 256 and op.type == Op.AvgPool:
1475 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1476
1477 # If the AvgPool version is used, we don't need to do anything else
1478 if op.type == Op.AvgPool:
1479 return op
1480
1481 # Create the quantization parameters for the unit weight tensor
1482 weight_quant = ifmq.clone()
1483 weight_quant.min = 0
1484 weight_quant.max = 255
1485 weight_quant.scale_f32 = weight_scale
1486 weight_quant.zero_point = 0
1487
1488 # Set weight shape to [H,W,C,B]
1489 weight_shape = [h, w, shape[3], shape[0]]
1490
1491 # Add unit weight tensor
1492 op.set_input_tensor(
1493 create_const_tensor(
1494 "weights",
1495 weight_shape,
1496 inp.dtype,
1497 np.ones(weight_shape),
1498 value_dtype=np.uint8,
1499 quantization=weight_quant,
1500 ),
1501 1,
1502 )
1503 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
1504
1505 # Add None bias tensor
1506 op.inputs.append(None)
1507 # Add bias tensor
1508 if bias:
1509 bias_shape = [shape[-1]]
1510 op.set_input_tensor(
1511 create_const_tensor(
1512 "bias",
1513 bias_shape,
1514 inp.dtype,
1515 np.ones(bias_shape) * bias,
1516 value_dtype=np.int32,
1517 quantization=None,
1518 ),
1519 2,
1520 )
1521
1522 return op
1523
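# Minimal numpy sketch (not used by the optimiser) of the MEAN decomposition above for the
# int8 path: a unit-weight sum over H and W followed by a multiplication with a scale just
# below 1/N. The function name and arguments are hypothetical and exist only to illustrate
# the arithmetic; the rounding behaviour of the real fixed-point pipeline is ignored and the
# ifm is assumed to be a float NHWC array.
def _mean_decomposition_sketch(ifm: np.ndarray) -> np.ndarray:
    n = ifm.shape[1] * ifm.shape[2]  # number of averaged elements (H * W)
    eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0  # same rounding adjustment as above
    summed = ifm.sum(axis=(1, 2), keepdims=True)  # what the unit-weight depthwise computes
    return summed * (1 / (n - eps))  # what the trailing Mul with the 1/N scalar computes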
1524
1525def optimise_quantize(op: Operation, arch, nng):
1526
1527 if op.type == Op.Quantize and op.run_on_npu:
1528
1529 ifm, ofm = op.get_ifm_ofm()
1530 input_values = ifm.values
1531
1532 # Guard clause - input not const or no values to quantize
1533 if ifm.ops[0].type != Op.Const or input_values is None:
1534 return op
1535
1536 # Scalar (0-d) numpy value - convert it to an indexable array
1537 if input_values.ndim == 0:
1538 input_values = np.array([input_values])
1539
1540 # Requantize int8 to int8 or int16 to int16
1541 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
1542
1543 # scale needs to use double precision to match TFLite reference kernel
1544 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1545 effective_multiplier, effective_shift = quantise_scale(effective_scale)
1546
1547 requantized_vals = []
1548 for val in input_values.flatten():
1549 input_val = val - ifm.quantization.zero_point
1550
1551 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1552 ofm_val += ofm.quantization.zero_point
1553
1554 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1555 requantized_vals.append(clamped_ofm_value)
1556
1557 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1558 ofm.values.shape = input_values.shape
1559
1560 # Case: Float input - quantize to int
1561 elif ifm.dtype.type == BaseType.Float:
1562
1563 quantized_vals = []
1564 for val in input_values:
1565
1566 # Derive quantized value
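# e.g. with ofm scale 0.5 and ofm zero point -128, a float value of 3.0 quantizes to
# 3.0 / 0.5 + (-128) = -122 before clipping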
1567 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
1568 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1569 quantized_vals.append(clamped_quantized_val)
1570
1571 # Pass the statically calculated quant values to the output tensor
1572 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1573
1574 # Unsupported data type
1575 else:
1576 return op
1577
1578 # Make quantize op const and disconnect from parent node
1579
1580 # Remove reference of the current quant op from the parent tensor's consumer list
1581 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1582
1583 # Clear any references to parent node
1584 op.inputs = []
1585
1586 # Convert this quantize op to const
1587 op.type = Op.Const
1588
1589 return op
1590
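# Minimal floating-point sketch (not used by the optimiser) of the int8/int16 requantisation
# performed above; the real code goes through quantise_scale and
# fp_math.multiply_by_quantized_multiplier so that it matches the TFLite reference kernel
# bit exactly. The function name and the qmin/qmax defaults are hypothetical (chosen for int8).
# Usage: _requantize_sketch([10, 20], 0.5, 0, 0.25, 0) -> array([20, 40])
def _requantize_sketch(values, ifm_scale, ifm_zp, ofm_scale, ofm_zp, qmin=-128, qmax=127):
    out = []
    for val in np.asarray(values).flatten():
        real = (int(val) - ifm_zp) * ifm_scale  # dequantise with the input parameters
        requant = int(round(real / ofm_scale)) + ofm_zp  # requantise with the output parameters
        out.append(max(min(requant, qmax), qmin))  # clamp to the output range
    return np.array(out, dtype=np.int64)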
1591
1592def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1593 """Static optimisation for SHAPE operator output value known at compile time"""
1594
1595 # Disconnect the SHAPE operator from its parent and transform the SHAPE op into a constant
1596
1597 if op.type == Op.Shape and op.run_on_npu:
1598
1599 ifm, ofm = op.get_ifm_ofm()
1600
1601 if len(ifm.shape) != ofm.shape[0]:
1602 return op
1603
1604 # Remove reference of the current shape op from the parent tensor's consumer list
1605 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1606
1607 # Clear any references to parent node
1608 op.inputs = []
1609
1610 # Convert this SHAPE op to const
1611 op.type = Op.Const
1612
1613 # Write the input tensor's shape as the constant output values
1614 ofm.values = np.array(ifm.shape)
1615
1616 return op
1617
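# For illustration, with made-up shapes: a SHAPE op whose input tensor has shape
# [1, 224, 224, 3] is rewritten into an Op.Const whose output values are [1, 224, 224, 3].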
1618
1619def supported_operator_check(op, arch, nng):
1620 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
1621 return op
1622
1623
1624def tflite_optimise_graph(nng, arch):
1625 # Compile-time static optimisations
1626 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1627
1628 for idx, sg in enumerate(nng.subgraphs):
1629 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1630 nng,
1631 sg,
1632 arch,
1633 [],
1634 optimisation_list,
1635 rewrite_unsupported=False,
1636 )
1637
1638 # Pre-processing step
1639 pre_process_list = [
1640 supported_operator_check,
1641 set_ifm_ofm_op_shapes,
1642 ]
1643
1644 for idx, sg in enumerate(nng.subgraphs):
1645 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1646 nng,
1647 sg,
1648 arch,
1649 [],
1650 pre_process_list,
1651 rewrite_unsupported=False,
1652 )
1653
1654 # Handle Concat Ops
1655 for idx, sg in enumerate(nng.subgraphs):
1656 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1657 sg.refresh_after_modification()
1658
1659 # Handle Split Ops
1660 for idx, sg in enumerate(nng.subgraphs):
1661 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1662 nng,
1663 sg,
1664 arch,
1665 [],
1666 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1667 rewrite_unsupported=False,
1668 )
1669
1670 for idx, sg in enumerate(nng.subgraphs):
1671 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1672 nng,
1673 sg,
1674 arch,
1675 [rewrite_split_ops],
1676 [],
1677 rewrite_unsupported=False,
1678 )
1679
1680 # Handle sg input output
1681 for idx, sg in enumerate(nng.subgraphs):
1682 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1683 nng,
1684 sg,
1685 arch,
1686 [],
1687 [fix_sg_input_output],
1688 rewrite_unsupported=False,
1689 )
1690
1691 # Removal of memory-only operators
1692 for sg in nng.subgraphs:
1693 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
1694 sg.refresh_after_modification()
1695
1696 # Rewrite of operators
1697 op_rewrite_list = [
1698 set_tensor_equivalence,
1699 convert_mean_to_depthwise_conv_or_avgpool,
1700 convert_depthwise_to_conv,
1701 convert_conv_to_fc,
1702 convert_softmax,
1703 convert_prelu,
1704 optimise_strided_conv,
1705 convert_hardswish_to_lut,
1706 rewrite_fully_connected_input,
1707 convert_batched_fc_shape,
1708 fixup_conv2d_backprop,
1709 fixup_relus_with_differing_ifm_ofm_scaling,
1710 reorder_depthwise_weights,
1711 fixup_resize,
1712 fixup_bias_tensors,
1713 fixup_asymmetric_weights,
1714 convert_mul_max_to_abs_or_lrelu,
1715 convert_lrelu,
1716 convert_tanh_sigmoid_to_lut,
1717 replace_pad_by_hw_pad,
1718 ]
1719
1720 for idx, sg in enumerate(nng.subgraphs):
1721 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1722 nng,
1723 sg,
1724 arch,
1725 [],
1726 op_rewrite_list,
1727 rewrite_unsupported=False,
1728 )
1729
1730 for idx, sg in enumerate(nng.subgraphs):
1731 # remove passthrough tensors and attempt further optimizations
1732 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1733 nng,
1734 sg,
1735 arch,
1736 [remove_passthrough_tensor],
1737 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1738 )
1739
1740 # Removal of SplitSliceRead needs to be done after the optimisations have been performed,
1741 # since the ifm/ofm shapes are of importance to this function
1742 for sg in nng.subgraphs:
1743 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1744 sg.refresh_after_modification()
1745
1746 return nng