# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op
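
# Worked example (illustrative shapes, not taken from a real network): concatenating two
# [1, 8, 8, 16] inputs along the depth axis gives axis_4D = 3, so the first AvgPool writes at
# offset [0, 0, 0, 0] and the second at [0, 0, 0, 16]; the final assert checks that the
# accumulated offset (32) matches the OFM depth.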


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt
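
# Worked example (illustrative values): for SAME padding with a 3x3 kernel, stride 1 and a
# 224x224 IFM, needed_total_padding returns xpad = ypad = 2, split as left_pad = top_pad = 1
# and right_pad = bottom_pad = 1, giving padding = skirt = (1, 1, 1, 1); at stride 2 the total
# horizontal padding would be 1 and the extra column goes to the right (left_pad = 0, right_pad = 1).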


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype == DataType.uint8:
        weight_value_dtype = np.uint8
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        if ofm_dtype == DataType.int8:
            weight_value_dtype = np.int8
        else:
            assert ofm_dtype == DataType.int16
            weight_value_dtype = np.int16

        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm.dtype,
            np.array(weight_values).reshape(weight_shape),
            value_dtype=weight_value_dtype,
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()

    return op
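
# Worked example (illustrative): with upscale_factor = 2 the kernel above is 2x2 and
# centre_coeff = (2 // 2) * 2 + (2 // 2) = 3, i.e. the single 1-coefficient lands in the
# bottom-right position ('D' in the diagram); all other weights stay 0.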


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # keep 1x1 kernel and average pool
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()

    return op
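
# Worked example (illustrative): an x8 upscale gives n = log2(8) = 3, so the loop above emits
# two intermediate nearest-neighbour x2 stages and the final clone performs the last x2; for
# ResizeBilinear without align_corners that final op ends up with an 8x8 average-pool kernel
# and explicit right/bottom padding of 7.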


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape
    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op
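
# Worked example (illustrative): a FullyConnected op with an IFM batch of 8 and depth 512 maps
# through batching_split to h, w = (2, 4), so the NPU sees the IFM as Shape4D([1, 2, 4, 512]),
# and the 2D IO weights are expanded to 4D HWIO with H = W = 1.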


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op
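
# Worked example (illustrative, assuming the padding check above passes): a stride-2 Conv2D with
# a 3x3 kernel over an IFM of Shape4D([1, 32, 64, 3]) is rewritten to read the IFM as
# [1, 32, 32, 6]; the kernel width is first padded from 3 to 4 columns with the weight zero
# point, then folded to width 2 with twice the input channels, and the horizontal stride becomes 1.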


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(np.int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(np.int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.RescaleAdd, op.name + "_add")
        add_op.rescale = (1, 0)  # No scale or shift
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |   \   /           This subgraph can be replaced with either
    |    Mul            an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scaling", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in the TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescale the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_prelu
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.type = Op.RescaleMul
        mul_alpha.rescale = [alpha_scale, alpha_shift]
    elif np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = 1
    alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [], ifm.dtype, [scalar], quantization=quantization)
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)
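
# Worked example (illustrative quantisation parameters): for an int8 Sigmoid with
# ifm_scale = 0.1, zp_in = 0, ofm_scale = 1 / 256 and zp_out = -128, the LUT entry for x = 0 is
# round_away_zero(-128 + sigmoid(0.0) / (1 / 256)) = round_away_zero(-128 + 128) = 0, which is
# already inside the [-128, 127] clamp.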


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")
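
# Worked example (illustrative): with alpha = 0.1 and identical IFM/OFM quantisation around a
# zero point of 0, entries at or above the zero point pass through the identity scaling
# (x maps to x), while a negative entry such as x = -64 goes through the alpha scaling and maps
# to approximately round(0.1 * -64) = -6.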


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
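
# Worked example (illustrative): for a 7-wide kernel (max_size = 3) at stride 2, a leading pad
# of 1 is rejected (it is neither max_size nor a multiple of the stride), while 0, 2 and 3 are
# accepted; at stride 1 any leading pad passes the modulo test.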


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                np.uint8,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
    return op


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op
1358
1359
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001360def fixup_bias_tensors(op, arch, nng):
1361 if op.type.needs_bias() and op.bias is None:
1362 # Op has no bias, add bias tensor filled with zeros
1363 nr_biases = op.inputs[1].shape[-1]
1364 bias_values = [0] * nr_biases
1365 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001366 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
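# For example (assumed shapes): for a depthwise weight tensor of shape [k_h, k_w, 1, C], such as
# the one created in replace_pad_by_hw_pad above, nr_biases is C, so a "<op.name>_bias" constant
# holding C int32 zeros is attached at the op's bias input index.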
1367
1368 return op
1369
1370
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001371def fixup_asymmetric_weights(op, arch, nng):
1372 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1373 if op.ifm.dtype == DataType.int8:
1374 if not np.all(op.weights.quantization.zero_point == 0):
1375 print(f"Warning: {op.type} '{op.name}' has asymmetric weights; zero points have been adjusted.")
1376 op.weights.quantization.zero_point *= 0
1377
1378 return op
1379
1380
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001381def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1382 if op.type == Op.Mean and op.run_on_npu:
1383 keep_dims = op.attrs.get("keep_dims", False)
1384 inp, axis = op.inputs
1385 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001386 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001387 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001388 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001389
1390 # Height and width axes have different indices depending on the number of dimensions
1391 if axis.shape == [] or axis.shape[0] == 1: # single axis
1392 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1393 if dims in (2, 3):
1394 if axis == 0:
1395 h, w = shape[axis], 1
1396 else:
1397 h, w = 1, shape[axis]
1398 else:
1399 if axis == 1:
1400 h, w = shape[axis], 1
1401 else:
1402 h, w = 1, shape[axis]
1403 else: # multiple axes
1404 axis = sorted(axis.values)
1405 h, w = [shape[i] for i in axis]
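# Illustrative examples (assumed shapes): for a 4D input of shape [1, 16, 24, 8], axis = [1, 2]
# gives h, w = 16, 24 and a single axis = 2 gives h, w = 1, 24; for a 3D input [16, 24, 8],
# axis = 0 gives h, w = 16, 1.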
1406
1407 # Set necessary depthwise attributes
1408 op.attrs.update(
1409 {
1410 "padding": Padding.VALID,
1411 "stride_h": 1,
1412 "stride_w": 1,
1413 "strides": (1, 1, 1, 1),
1414 "depth_multiplier": 1,
1415 "channel_multiplier": 1,
1416 "dilation_h_factor": 1,
1417 "dilation_w_factor": 1,
1418 "dilation": (1, 1, 1, 1),
1419 }
1420 )
1421 # Change op type
1422 op.type = Op.DepthwiseConv2DBias
1423 # Set IFM/OFM shapes after changing op type
1424 op.set_ifm_ofm_shapes()
1425
1426 weight_scale, bias = 1, None
1427 ofmq, ifmq = op.ofm.quantization, inp.quantization
1428 # Set rounding mode, scaling and zero point based on which reference implementation to match
1429 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1430 if inp.dtype == DataType.uint8:
1431 # This attribute means a different scaling calculation is used in order to match reference
1432 op.low_precision_scaling = True
1433 weight_scale = h * w
1434 # Set zero points to 0, as they will be accounted for by the bias term
1435 foq = ofmq.clone()
1436 foq.zero_point = 0
1437 fiq = ifmq.clone()
1438 fiq.zero_point = 0
1439 op.forced_input_quantization = fiq
Johan Alfvén17009392022-08-30 09:14:56 +02001440 bias_term = ofmq.zero_point - round_up_to_int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001441 # If the bias term is outside uint8 range, we need an Add op to apply it.
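# Worked example (assumed quantization): with ifmq.zero_point = 128, equal ifm/ofm scales and
# ofmq.zero_point = 0, bias_term becomes 0 - 128 = -128, which is outside [0, 255], so the
# separate Add op below is needed.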
1442 if bias_term < 0 or bias_term > 255:
1443 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1444 # Bias term has higher bitness (i32) than input/output (u8).
1445 # 16 bits is enough: since the bias is added to/subtracted from a u8 value,
1446 # it can only effectively assume values in the range [-255, 255].
1447 intermediate.dtype = DataType.int16
1448 intermediate.quantization.zero_point = 0
1449 add_op = Operation(Op.Add, op.name + "_bias")
1450 add_op.forced_output_quantization = foq
1451 add_op.add_input_tensor(intermediate)
1452 quant = QuantizationParameters()
1453 quant.zero_point = 0
1454 bias_term_tens = create_const_tensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001455 op.name + "_bias",
1456 [1, 1, 1, 1],
1457 DataType.int16,
1458 [bias_term],
1459 np.int16,
1460 quantization=quant,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001461 )
1462 add_op.add_input_tensor(bias_term_tens)
1463 add_op.set_output_tensor(op.ofm)
1464 add_op.set_ifm_ofm_shapes()
1465 add_op.activation = op.activation
1466 op.activation = None
1467 op.set_output_tensor(intermediate)
1468 op.set_ifm_ofm_shapes()
1469 # If not, we can just do it with the OFM zero point.
1470 else:
1471 foq.zero_point = bias_term
1472 op.forced_output_quantization = foq
1473 else:
1474 assert inp.dtype == DataType.int8
1475 # Use a depthwise to calculate the sum,
1476 # followed by a multiplication by 1/N to get the MEAN
1477 weight_scale = 1
1478 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1479 intermediate.dtype = DataType.int16
1480 mul_op = Operation(Op.Mul, op.name + "_mul")
1481 mul_op.add_input_tensor(intermediate)
1482 # Create scalar containing 1/N
1483 quant = QuantizationParameters()
1484 quant.zero_point = 0
1485 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1486 # while rounding mode NATURAL would round this to -1.
1487 # This can only occur if N is even, and can be emulated by
1488 # multiplying by a number that is slightly smaller than 1/N.
1489 # The adjustment must be so small that no other roundings are affected;
1490 # the value used is based on the worst case,
1491 # which is a sum of 256 * N (the maximum sum that can occur with int8)
1492 n = int(h * w)
1493 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1494 quant.scale_f32 = 1 / (n - eps)
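# Worked example (assumed values): for a 2x2 mean (n = 4), eps = 1/1280 and the scale becomes
# 1/3.99921875, so a sum of -6 maps to roughly -1.5003, which NATURAL rounding takes to -2,
# matching the reference result for -6/4 = -1.5; without eps it would round to -1.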
1495 scalar = create_const_tensor(
1496 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1497 )
1498 mul_op.add_input_tensor(scalar)
1499 mul_op.set_output_tensor(op.ofm)
1500 mul_op.set_ifm_ofm_shapes()
1501 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1502 mul_op.activation = op.activation
1503 op.activation = None
1504 op.set_output_tensor(intermediate)
1505 op.set_ifm_ofm_shapes()
1506 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1507 # Here we can just use a simple AvgPool with truncating rounding,
1508 # as we're emulating simple integer division.
1509 op.rounding_mode = NpuRoundingMode.TRUNCATE
1510 op.type = Op.AvgPool
1511 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1512 else:
1513 op.rounding_mode = NpuRoundingMode.NATURAL
1514 weight_scale = 1 / (h * w)
1515 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1516 bias = -ifmq.zero_point * h * w
1517 fiq = ifmq.clone()
1518 fiq.zero_point = 0
1519 op.forced_input_quantization = fiq
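# Worked example (assumed values): with ifmq.zero_point = 5 and a 2x3 mean, bias = -5 * 2 * 3 = -30,
# cancelling the zero-point contribution of the six summed input values.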
1520
1521 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001522 def extend_dims(dim, in_shape):
1523 if dim < 4:
1524 in_shape = [1] + in_shape
1525 if dim == 2:
1526 in_shape += [1]
1527 return in_shape
1528
1529 if dims < 4 or dims_ofm < 4:
1530 # Fix the ofm dimension when keep_dims is false
1531 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1532 if isinstance(axis, int) and dims_ofm + 1 == dims:
1533 ofm_shape.insert(axis, 1)
1534 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1535 for i in axis:
1536 ofm_shape.insert(i, 1)
1537 shape = extend_dims(dims, shape)
1538 dims_ofm = len(ofm_shape)
1539 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1540 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001541
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001542 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
1543 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001544 shape = [shape[0], 1, h * w, shape[3]]
1545 op.ifm_shapes[0] = Shape4D(shape)
1546 if h > 256 and op.type == Op.AvgPool:
1547 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
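# For example (assumed sizes): a 100x4 depthwise mean is reshaped to a 1x400 IFM, and a 300x2
# AvgPool mean gets ksize (1, 1, 600, 1).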
1548
1549 # If the AvgPool version is used, we don't need to do anything else
1550 if op.type == Op.AvgPool:
1551 return op
1552
1553 # Create quantization for the unit weight tensor
1554 weight_quant = ifmq.clone()
1555 weight_quant.min = 0
1556 weight_quant.max = 255
1557 weight_quant.scale_f32 = weight_scale
1558 weight_quant.zero_point = 0
1559
1560 # Set weight shape to [H,W,C,B]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001561 weight_shape = [h, w, shape[3], shape[0]]
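# For example (assumed shapes): with shape = [1, 16, 24, 8] and axis = [1, 2], the unit weight
# tensor below gets shape [16, 24, 8, 1] and is filled with ones.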
1562
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001563 # Add unit weight tensor
1564 op.set_input_tensor(
1565 create_const_tensor(
1566 "weights",
1567 weight_shape,
1568 inp.dtype,
1569 np.ones(weight_shape),
1570 value_dtype=np.uint8,
1571 quantization=weight_quant,
1572 ),
1573 1,
1574 )
James Peet7519d502021-07-19 16:47:58 +01001575 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001576
1577 # Add None bias tensor
1578 op.inputs.append(None)
1579 # Add bias tensor
1580 if bias:
1581 bias_shape = [shape[-1]]
1582 op.set_input_tensor(
1583 create_const_tensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001584 "bias",
1585 bias_shape,
1586 inp.dtype,
1587 np.ones(bias_shape) * bias,
1588 value_dtype=np.int32,
1589 quantization=None,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001590 ),
1591 2,
1592 )
1593
1594 return op
1595
1596
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001597def optimise_quantize(op: Operation, arch, nng):
1598
1599 if op.type == Op.Quantize and op.run_on_npu:
1600
1601 ifm, ofm = op.get_ifm_ofm()
1602 input_values = ifm.values
1603
1604 # Guard clause - input not const or no values to quantize
1605 if ifm.ops[0].type != Op.Const or input_values is None:
1606 return op
1607
1608 # Scalar value in a numpy array; convert it to an indexable array
1609 if input_values.ndim == 0:
1610 input_values = np.array([input_values])
1611
Fredrik Svedberg11563172022-07-06 14:54:12 +02001612 # requantized int8 to int8 or int16 to int16
1613 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001614
1615 # scale needs to use double precision to match TFLite reference kernel
1616 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1617 effective_multiplier, effective_shift = quantise_scale(effective_scale)
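# Illustrative sketch (assumed quantization): with ifm scale 0.5 and ofm scale 0.25 the effective
# scale is 2.0, so a stored value of 10 with both zero points at 0 requantizes to roughly
# 10 * 2 + 0 = 20 before clamping to the ofm quantization range.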
1618
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001619 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001620 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001621 input_val = val - ifm.quantization.zero_point
1622
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001623 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1624 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001625
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001626 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1627 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001628
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001629 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1630 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001631
1632 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001633 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001634
1635 quantized_vals = []
1636 for val in input_values:
1637
1638 # Derive quantized value
1639 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001640 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1641 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001642
1643 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001644 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
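# Illustrative sketch (assumed quantization): a float input value of 4.0 with ofm scale 0.5 and
# zero point 2 maps to 4.0 / 0.5 + 2 = 10.0, which lies inside [quant_min, quant_max] and is
# stored as 10 in the output tensor.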
1645
1646 # Unsupported data type
1647 else:
1648 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001649
1650 # Make quantize op const and disconnect from parent node
1651
1652 # Remove reference of the current quant op from the parent tensor's consumer list
1653 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1654
1655 # Clear any references to parent node
1656 op.inputs = []
1657
1658 # Convert this quantize op to const
1659 op.type = Op.Const
1660
1661 return op
1662
1663
Ayaan Masood4965fae2022-06-29 11:30:57 +01001664def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1665 """Static optimisation for the SHAPE operator, whose output value is known at compile time"""
1666
1667 # Disconnect the SHAPE operator from its parent and transform it into a constant
1668
1669 if op.type == Op.Shape and op.run_on_npu:
1670
1671 ifm, ofm = op.get_ifm_ofm()
1672
1673 if len(ifm.shape) != ofm.shape[0]:
1674 return op
1675
1676 # Remove reference of the current shape op from the parent tensor's consumer list
1677 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1678
1679 # Clear any references to parent node
1680 op.inputs = []
1681
1682 # Convert this SHAPE op to const
1683 op.type = Op.Const
1684
1685 # Write the input tensor's shape into the output tensor's values
1686 ofm.values = np.array(ifm.shape)
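# For example (assumed shape): an IFM of shape [1, 56, 56, 32] yields a constant OFM holding the
# values [1, 56, 56, 32], and the SHAPE op is then treated like any other Const.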
1687
1688 return op
1689
1690
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001691def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001692 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001693 return op
1694
1695
1696def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001697 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001698 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1699
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001700 for idx, sg in enumerate(nng.subgraphs):
1701 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001702 nng,
1703 sg,
1704 arch,
1705 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001706 optimisation_list,
1707 rewrite_unsupported=False,
1708 )
1709
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001710 # Pre-processing step
1711 pre_process_list = [
1712 supported_operator_check,
1713 set_ifm_ofm_op_shapes,
1714 ]
1715
Ayaan Masood4965fae2022-06-29 11:30:57 +01001716 for idx, sg in enumerate(nng.subgraphs):
1717 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1718 nng,
1719 sg,
1720 arch,
1721 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001722 pre_process_list,
1723 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001724 )
1725
1726 # Handle Concat Ops
1727 for idx, sg in enumerate(nng.subgraphs):
1728 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1729 sg.refresh_after_modification()
1730
1731 # Handle Split Ops
1732 for idx, sg in enumerate(nng.subgraphs):
1733 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1734 nng,
1735 sg,
1736 arch,
1737 [],
1738 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1739 rewrite_unsupported=False,
1740 )
1741
1742 for idx, sg in enumerate(nng.subgraphs):
1743 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001744 nng,
1745 sg,
1746 arch,
1747 [rewrite_split_ops],
1748 [],
1749 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001750 )
1751
1752 # Handle sg input output
1753 for idx, sg in enumerate(nng.subgraphs):
1754 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001755 nng,
1756 sg,
1757 arch,
1758 [],
1759 [fix_sg_input_output],
1760 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001761 )
1762
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001763 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001764 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001765 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001766 sg.refresh_after_modification()
1767
1768 # Rewrite of operators
1769 op_rewrite_list = [
1770 set_tensor_equivalence,
1771 convert_mean_to_depthwise_conv_or_avgpool,
1772 convert_depthwise_to_conv,
1773 convert_conv_to_fc,
1774 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001775 convert_prelu,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001776 optimise_strided_conv,
1777 convert_hardswish_to_lut,
1778 rewrite_fully_connected_input,
1779 convert_batched_fc_shape,
1780 fixup_conv2d_backprop,
1781 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001782 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001783 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001784 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001785 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001786 convert_mul_max_to_abs_or_lrelu,
1787 convert_lrelu,
1788 convert_tanh_sigmoid_to_lut,
1789 replace_pad_by_hw_pad,
1790 ]
1791
1792 for idx, sg in enumerate(nng.subgraphs):
1793 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001794 nng,
1795 sg,
1796 arch,
1797 [],
1798 op_rewrite_list,
1799 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001800 )
1801
1802 for idx, sg in enumerate(nng.subgraphs):
1803 # Remove passthrough tensors and attempt further optimisations
1804 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1805 nng,
1806 sg,
1807 arch,
1808 [remove_passthrough_tensor],
1809 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1810 )
1811
1812 # Removal of SplitSliceRead needs to be done after the optimisations above have been performed,
1813 # since this step relies on the ifm/ofm_shapes that those optimisations set
1814 for sg in nng.subgraphs:
1815 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1816 sg.refresh_after_modification()
1817
1818 return nng