# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
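#
# Note: most of the rewrite passes below follow the convention of taking (op, arch, nng) - or
# (tens, arch, nng) for tensor rewrites - and returning the possibly replaced op/tensor, which
# is the shape expected by the rewrite_graph traversal; the driver that applies them lives
# outside this excerpt.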
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)
66
67
68def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
69 """Creates an average pool for the given concat op/input feature map"""
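    # A concat is lowered to one AvgPool NOP per input feature map, each writing its IFM into the
    # shared OFM at write_offset; e.g. for a hypothetical depth-wise concat of two 1x8x8x16 inputs,
    # the second NOP would be created with write_offset == Shape4D(0, 0, 0, 16)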
70 ofm = concat_op.ofm
71 avgpool_op = create_avgpool_nop(name)
72 avgpool_op.inputs = [ifm]
73 avgpool_op.outputs = [ofm]
74
75 avgpool_op.write_offset = write_offset
76 avgpool_op.write_shape = ifm_shape
77 ofm.ops.append(avgpool_op)
78 DebugDatabase.add_optimised(concat_op, avgpool_op)
79 avgpool_op.ifm_shapes.append(ifm_shape)
80 avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
81 avgpool_op.memory_function = Op.ConcatSliceWrite
82 return avgpool_op
83
84
85def remove_passthrough_tensor(tens, arch, nng):
86 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
87 assert len(tens.ops[0].inputs) == 1
88 tens = tens.ops[0].inputs[0]
89 return tens
90
91
92def rewrite_concat_ops(op, arch):
93 if not op.run_on_npu or not op.type.is_concat_op():
94 return
95
96 axis_4D = 0
97 ofm = op.ofm
98 ofm.ops = []
99 offset = 0
100
101 unfuse_activation_function(op)
102
103 if op.type == Op.Pack:
104 # Pack is also referred to as Stack
105 axis = int(op.attrs["axis"])
106 if axis < 0: # Convert to positive axis
107 axis = len(op.inputs[0].shape) + 1 + axis
108
109 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
110
111 axis_4D = axis + (4 - len(desired_shape))
112
113 for idx, inp in enumerate(op.inputs):
114 op.ifm_shapes[idx] = Shape4D(desired_shape)
115 op.type = Op.PackReshaped
116
117 inputs, axis = op.get_concat_inputs_axis()
118 for idx, inp in enumerate(inputs):
119 if op.type != Op.PackReshaped:
120 op.ifm_shapes[idx] = Shape4D(inp.shape)
121 if axis >= 0:
122 axis_4D = axis + (4 - len(inp.shape))
123 else:
124 axis_4D = axis
125 write_offset = [0, 0, 0, 0]
126 write_offset[axis_4D] = offset
127 concat_end = offset + op.ifm_shapes[idx][axis_4D]
128 create_avg_pool_for_concat(
129 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
130 )
131 offset = concat_end
132 assert ofm.shape[axis] == offset
133
134 return op
135
136
137def rewrite_split_ops(tens, arch, nng):
138
139 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
140 split_op = tens.ops[0]
141
142 # Not supported so leave it and run on CPU
143 if not split_op.run_on_npu:
144 return tens
145
146 inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
147
148 tens.ops = []
149 new_op = Operation(Op.SplitSliceRead, split_op.name)
150 new_op.inputs = [inp]
151 ofm_shape_idx = 0
Tim Hall51a8dce2021-12-20 16:49:27 +0000152 if None in (offset_end, offset_start):
153 read_shape = None
154 else:
155 # the read shape is relative to each start offset
156 read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200157
158 # For Split the offset cannot be extracted from the tensor so it has to
159 # be calculated from the index of the output tensor
160 if axis is not None:
161 # Get the start and end of the split
162 offset_start = [0] * 4
163 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
164 for idx, out in enumerate(outputs):
165 if axis_4D_list is not None:
166 axis_4D = axis_4D_list[idx]
167 else:
168 split_op.ofm_shapes[idx] = Shape4D(out.shape)
169 if axis >= 0:
170 axis_4D = axis + (4 - len(out.shape))
171 else:
172 axis_4D = axis
173
174 if out == tens:
175 ofm_shape_idx = idx
176 read_shape = split_op.ofm_shapes[idx]
177 break
178
179 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
180
181 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
182 new_op.read_shapes[0] = read_shape
183 new_op.run_on_npu = True
184 new_op.set_output_tensor(tens)
185 new_op.ifm_shapes.append(Shape4D(inp.shape))
186 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
187 DebugDatabase.add_optimised(split_op, new_op)
188
189 return tens
190
191
192def remove_SplitSliceRead(op, arch):
193
194 if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
196 if (
197 len(op.ofm.consumer_list) == 1
198 and op.ofm.consumer_list[0] is not None
199 and op.ofm.consumer_list[0].run_on_npu
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +0200200 and op.ofm.consumer_list[0].type not in memory_only_ops
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200201 and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
202 ):
203 # SplitSliceRead can be performed by tensor consumer
204 cons_op = op.ofm.consumer_list[0]
Patrik Gustavssonf1580f02021-09-01 12:43:02 +0200205 move_splitsliceread_to_consumer(op, cons_op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200206 else:
207 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
208 avgpool_op.add_input_tensor(op.ifm)
209 avgpool_op.outputs = [op.ofm]
210 op.ofm.ops.remove(op)
211 op.ofm.ops.append(avgpool_op)
212 avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
213 avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
214 avgpool_op.read_offsets[0] = op.read_offsets[0]
215 avgpool_op.read_shapes[0] = op.read_shapes[0]
216
217 op.ifm.consumer_list.remove(op)
218 DebugDatabase.add_optimised(op, avgpool_op)
219
220
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200221def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
222 k_w, k_h = kernel.dilated_wh()
223 s_x, s_y = kernel.stride
224 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
225 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
226 if padding_type == Padding.SAME:
227 left_pad = (xpad + 0) // 2
228 right_pad = (xpad + 1) // 2
229 top_pad = (ypad + 0) // 2
230 bottom_pad = (ypad + 1) // 2
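        # e.g. (hypothetical numbers) a 3x3 kernel with stride 1 gives xpad == 2, split as
        # left_pad == 1 and right_pad == 1; an odd total puts the extra row/column at the
        # bottom/right, matching the TensorFlow SAME padding convention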
231 elif padding_type == Padding.VALID:
232 left_pad = 0
233 right_pad = 0
234 top_pad = 0
235 bottom_pad = 0
236 elif padding_type == Padding.EXPLICIT:
237 # Padding is specified in a PAD operator which has been bypassed.
238 top, left, bottom, right = explicit_padding
239 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
240 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
Rickard Bolin9ae34552022-06-09 13:07:17 +0000241 elif padding_type == Padding.TILE:
242 # The values in the explicit padding only represent the "direction" in which to pad
243 top_pad, left_pad, bottom_pad, right_pad = explicit_padding
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200244 else:
Tim Hall0ab2edc2022-02-23 17:58:02 +0000245 raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200246 padding = (top_pad, left_pad, bottom_pad, right_pad)
247 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
248 return padding, skirt
249
250
251def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
252 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
253 if padding_type == Padding.SAME:
254 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
255 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
256 right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
257 bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
258 left_pad = max(kernel_width - 1 - right_pad, 0)
259 top_pad = max(kernel_height - 1 - bottom_pad, 0)
260 elif padding_type == Padding.VALID:
261 right_pad = max(kernel_width - 2, 0)
262 bottom_pad = max(kernel_height - 2, 0)
263 left_pad = kernel_width - 1
264 top_pad = kernel_height - 1
265 else:
Tim Hall0ab2edc2022-02-23 17:58:02 +0000266 raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200267 padding = (top_pad, left_pad, bottom_pad, right_pad)
268 skirt = padding
269 return padding, skirt
270
271
272def fixup_conv2d_backprop(op, arch, nng):
273 if op.type == Op.Conv2DBackpropInput:
274 # flip the inputs
275 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
276 op.type = Op.Conv2DBackpropInputSwitchedBias
Tim Hall3c5cfe92022-03-16 16:31:57 +0000277 op.ifm_resampling_mode = resampling_mode.TRANSPOSE
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200278
279 # Update strides
280 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
281
282 return op
283
284
285# Convert the op to an elementwise add
Tim Hall885033b2022-07-21 11:46:03 +0100286def convert_resize_1x1_to_add(op):
287 op.type = Op.Add # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200288 op.name = op.name + "_add"
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200289 # Create an input tensor filled with zeros
290 shape = op.ofm_shapes[0].as_list()
291 tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
James Peet7519d502021-07-19 16:47:58 +0100292 tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200293 tens.quantization = QuantizationParameters(0.0, 255.0)
294 tens.quantization.scale_f32 = 1.0
295 tens.quantization.zero_point = 0
296 tens.consumer_list = [op]
297 tens_op = op.inputs[1].ops[0]
298 tens_op.set_output_tensor(tens)
299 # Set the add inputs
300 op.inputs[1] = op.inputs[0]
301 op.inputs[0] = tens
302 op.set_ifm_ofm_shapes()
303
304 return op
305
306
# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling, which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
310def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
311 ifm = op.ifm
312 ofm = op.ofm
313 output_depth = ofm.shape[-1]
314 dw_op_attrs = {
315 "padding": Padding.VALID,
316 "stride_h": 1,
317 "stride_w": 1,
318 "strides": (1, 1, 1, 1),
319 "depth_multiplier": 1,
320 "channel_multiplier": 1,
321 "dilation_h_factor": 1,
322 "dilation_w_factor": 1,
323 "dilation": (1, 1, 1, 1),
324 }
325
    # change the resize op to a depthwise convolution
327 op.type = Op.DepthwiseConv2DBias
328 op.attrs.update(dw_op_attrs)
329 op.set_input_tensor(ifm, 0) # ifm tensor index
330 op.activation = None
331
332 # add input resample to resize by x2
333 op.ifm_resampling_mode = resampling_mode.NEAREST
334
335 # don't care about the rounding mode as it is nearest neighbor
336
337 # setup weight tensor
338 weight_quant = QuantizationParameters()
339 weight_quant.scale_f32 = 1.0 # no scaling as only a single non-zero coeff to select the desired value
340 weight_quant.zero_point = 0
341 weight_quant.quant_dim = 0
342 ofm_dtype = ofm.dtype
343 if ofm_dtype == DataType.uint8:
344 weight_value_dtype = np.uint8
345 weight_quant.quant_min = 0
346 weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
347 else:
348 if ofm_dtype == DataType.int8:
349 weight_value_dtype = np.int8
350 else:
351 assert ofm_dtype == DataType.int16
352 weight_value_dtype = np.int16
353
354 weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
355 weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
356
357 weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth] # HWIO
358
359 # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
360 # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
361 # below-and-right (i.e. next) to it (D).
362 # 0---1---2
363 # | A | B |
364 # 1---*---+
365 # | C | D |
366 # 2---+---+
367 weight_values = [0] * (upscale_factor * upscale_factor)
368 centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
369 weight_values[centre_coeff] = 1
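    # e.g. for a hypothetical upscale_factor of 2, centre_coeff == 3, giving the flattened 2x2
    # kernel [0, 0, 0, 1], i.e. value 'D' in the diagram above is selected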
370
371 # add weight tensor, this will discard the size tensor of the resize op
372 op.set_input_tensor(
373 create_const_tensor(
374 "weights",
375 weight_shape,
376 ofm.dtype,
377 np.array(weight_values).reshape(weight_shape),
378 value_dtype=weight_value_dtype,
379 quantization=weight_quant,
380 ),
381 1, # inputs tensor weight index
382 )
383
    # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
386 assert len(op.inputs) == 2
387 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +0200388 fixup_bias_tensors(op, None, None, DataType.int32)
Tim Hall885033b2022-07-21 11:46:03 +0100389
    # finally update the shape in case we've changed the tensor shapes or connections
391 op.set_ifm_ofm_shapes()
392
393 return op
394
395
# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
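# For example (hypothetical shapes): resizing 1x4x4x8 to 1x16x16x8 gives upscale_factor == 4 and
# n == 2, i.e. one intermediate nearest-neighbour x2 pass followed by a final x2 pass whose
# average pool (for ResizeBilinear) uses a 4x4 kernel.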
399def convert_resize_to_upscale_and_average_pool(op):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200400 pre_op = op
401 outputs = op.outputs
Rickard Boline546def2022-01-25 15:45:00 +0000402 dtype = op.ifm.dtype
Tim Hall885033b2022-07-21 11:46:03 +0100403
Rickard Boline546def2022-01-25 15:45:00 +0000404 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
Tim Hall47c76362022-07-18 21:26:47 +0100405 op.attrs["padding"] = Padding.SAME # doesn't really matter as the kernel is 1x1
Tim Hall3c5cfe92022-03-16 16:31:57 +0000406 op.ifm_resampling_mode = resampling_mode.NEAREST
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200407
408 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
Tim Hall47c76362022-07-18 21:26:47 +0100409
410 # Get upscale factor that was calculated in the supported operators check
411 upscale_factor = op.attrs["upscale_factor"]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200412
Rickard Boline546def2022-01-25 15:45:00 +0000413 # Calculate how many times 2x2 upscaling needs to be performed
Tim Hallf9267da2022-04-20 20:19:48 +0100414 # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
415 # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
Rickard Boline546def2022-01-25 15:45:00 +0000416 n = int(np.log2(upscale_factor))
417
Tim Hall885033b2022-07-21 11:46:03 +0100418 # Perform x2 upscaling n-1 times
Rickard Boline546def2022-01-25 15:45:00 +0000419 scaled_op = pre_op
420 for count in range(n - 1):
421 if count > 0:
422 scaled_op = op.clone(f"_{count}")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200423 scaled_op.inputs[0] = pre_op.outputs[0]
424
Tim Hall885033b2022-07-21 11:46:03 +0100425 # Nearest neighbor x2 upscaling
Tim Hall47c76362022-07-18 21:26:47 +0100426 upscaled_shape = upscaled_shape * 2
Rickard Boline546def2022-01-25 15:45:00 +0000427 shape = op.ofm_shapes[0].as_list()
428 shape[1:3] = upscaled_shape
429 out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
430 out_tens.quantization = op.outputs[0].quantization.clone()
431 scaled_op.set_output_tensor(out_tens)
432 pre_op = scaled_op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200433
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200434 scaled_op.set_ifm_ofm_shapes()
435
Tim Hall885033b2022-07-21 11:46:03 +0100436 # Last x2 upscaling
Rickard Boline546def2022-01-25 15:45:00 +0000437 if n > 1:
438 scaled_op = op.clone(f"_{n-1}")
439 scaled_op.inputs[0] = pre_op.outputs[0]
Tim Hall885033b2022-07-21 11:46:03 +0100440
441 if scaled_op.original_type == Op.ResizeBilinear:
442 if scaled_op.attrs["align_corners"]:
443 # no padding
444 scaled_op.attrs["padding"] = Padding.VALID
445 else:
446 # padding to the right and bottom (limits average pool to 8x8 kernel)
447 scaled_op.attrs["padding"] = Padding.EXPLICIT
448 scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
449
            # kernel size dependent on the upscaling factor
451 scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
452 else: # Op.ResizeNearestNeighbor
453 if scaled_op.attrs["align_corners"]:
454 # use depthwise conv to select the correct value
455 scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
456 else:
457 # keep 1x1 kernel and average pool
458 pass
459
Rickard Boline546def2022-01-25 15:45:00 +0000460 scaled_op.outputs = outputs
461 scaled_op.outputs[0].ops = [scaled_op]
462 scaled_op.set_ifm_ofm_shapes()
463
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200464 return op
465
466
Rickard Bolinfea15162022-07-04 16:19:16 +0000467def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
468 def _compute_interpolation_values(index, input_size, output_size):
469 scale = input_size / output_size
470 scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
471 lower_bound = max(np.floor(scaled_value), 0)
472
473 return scaled_value, lower_bound
474
475 def _compute_kernels(input_height, input_width, output_height, output_width):
476 kernels = []
477 for y in (1, 2):
478 for x in (1, 2):
479 sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
480 sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)
481
482 # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
483 # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
484 # top-to-bottom - same as the depthwise convolution strides across each tile
485 kernel = np.zeros((2, 2))
486 kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
487 kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
488 kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
489 kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
490 kernel *= 16
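                # scale the fractional 2x2 bilinear weights by 16 so they can be stored as
                # integer weights; this is compensated by the 1/16 weight quantization scale
                # set when the weight tensors are created below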
491 kernels.append(kernel)
492
493 return kernels
494
495 def _build_convolutions(op, kernels):
496 dw_op_attrs = {
497 "padding": Padding.TILE,
498 "stride_h": 1,
499 "stride_w": 1,
500 "strides": (1, 1, 1, 1),
501 "depth_multiplier": 1,
502 "channel_multiplier": 1,
503 "dilation_h_factor": 1,
504 "dilation_w_factor": 1,
505 "dilation": (1, 1, 1, 1),
506 }
507 ifm = op.ifm
508 ofm = op.ofm
509 ofm.ops = []
510 elem_size = 2 if ofm.dtype == DataType.int16 else 1
511
512 n, h, w, c = ifm.shape
513 _, _, ow, _ = ofm.shape
514
515 intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
516 intermediate_tens.quantization = op.outputs[0].quantization.clone()
517 avgpool_op = op
518 avgpool_op.name = "rb_init_avgpool"
519 avgpool_op.type = Op.AvgPool
520 avgpool_op.attrs["padding"] = Padding.VALID
521 avgpool_op.attrs["stride_w"] = 1
522 avgpool_op.attrs["stride_h"] = 1
523 avgpool_op.attrs["filter_width"] = 1
524 avgpool_op.attrs["filter_height"] = 1
525 avgpool_op.attrs["strides"] = [1, 1, 1, 1]
526 avgpool_op.attrs["ksize"] = [1, 1, 1, 1]
527
528 avgpool_op.add_input_tensor(ifm)
529 avgpool_op.set_output_tensor(intermediate_tens)
530 avgpool_op.set_ifm_ofm_shapes()
531
532 dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
533 dw_conv._original_type = Op.ResizeBilinear
534 dw_conv.write_shape = Shape4D(n, h, w, c)
535 dw_conv.write_offset = Shape4D(0, 0, 0, 0)
536
537 # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
538 # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
539 # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
540 # values to be incorrectly rounded
541 ofm.quantization.next_after = True
542 dw_conv.rounding_mode = NpuRoundingMode.NATURAL
543
544 # Double height and width stride to write the output of each of the four depthwise convolutions below
545 # interleaved with each other when combined with OFM tile base offsets.
546 dw_conv.ofm_stride_multiplier = [1, 2, 2] # C/H/W
547
        # Choose tile padding direction - pad by 1 with edge values in two directions.
549 # For example, TL (top left) will pad top and left in H/W-plane in all channels.
550 directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]] # TL, TR, BL, BR
551 for i in (0, 1):
552 for j in (0, 1):
553 index = i * 2 + j
554 dw_conv.name = f"depthwise_conv_{index}"
555 dw_op_attrs["explicit_padding"] = directions[index]
556 dw_conv.attrs.update(dw_op_attrs)
557
558 # This will offset the start of the write by modifying the Tile 0 base address
559 dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size
560
561 ofm.ops.append(dw_conv)
562 dw_conv.outputs = [ofm]
563
564 kernel = kernels[index]
565 shape = [2, 2, 1, c]
566 kernel = np.dstack([kernel] * c)
567
568 quant = QuantizationParameters()
569 quant.zero_point = 0
570 quant.scale_f32 = 1.0 / 16
571
572 dw_conv.inputs = []
573 dw_conv.add_input_tensor(intermediate_tens)
574 dw_conv.add_input_tensor(
575 create_const_tensor(
576 "weights",
577 shape,
578 intermediate_tens.dtype,
579 np.array(kernel).reshape(shape),
580 value_dtype=np.int8,
581 quantization=quant,
582 ),
583 )
584
            # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
            # need to append the bias tensor as resize ops only have 2 inputs
587 assert len(dw_conv.inputs) == 2
588 dw_conv.inputs.append(None)
Rickard Bolin017b4cc2022-09-23 10:16:48 +0000589 fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)
Rickard Bolinfea15162022-07-04 16:19:16 +0000590
591 dw_conv.set_ifm_ofm_shapes()
592 dw_conv = dw_conv.clone(f"_{index}")
593 return op
594
595 _, input_height, input_width, _ = op.ifm.shape
596 _, output_height, output_width, _ = op.ofm.shape
597
598 kernels = _compute_kernels(input_height, input_width, output_height, output_width)
599 op = _build_convolutions(op, kernels)
600
601 return op
602
603
Tim Hall885033b2022-07-21 11:46:03 +0100604def fixup_resize(op, arch, nng):
605 if op.type.is_resize_op() and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200606 if op.ifm_shapes[0] == op.ofm_shapes[0]:
Tim Hall885033b2022-07-21 11:46:03 +0100607 # Bypass the resize op which is essentially a NOP
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200608 op.inputs = op.inputs[:1]
609 op.type = Op.Identity
610 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
Tim Hall885033b2022-07-21 11:46:03 +0100611 convert_resize_1x1_to_add(op)
Rickard Bolinfea15162022-07-04 16:19:16 +0000612 elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
613 convert_resizebilinear_to_depthwise_convolutions(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200614 else:
Tim Hall885033b2022-07-21 11:46:03 +0100615 convert_resize_to_upscale_and_average_pool(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200616
617 return op
618
619
620def convert_nop_split_to_identity(op, arch, nng):
621 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
622 # the list comprehension should return a list with a single tensor
623 # if it shouldn't, remove_passthrough_tensor will fail appropriately
624 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
625 op.type = Op.Identity
626 return op
627
628
Ayaan Masooda2ec5aa2022-04-21 14:28:03 +0100629def rewrite_fully_connected_input(op: Operation, arch, nng):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200630
Ayaan Masooda2ec5aa2022-04-21 14:28:03 +0100631 if op.type == Op.FullyConnected:
632 new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
633 assert new_shape is not None, "Tensor can not be reshaped to 2D"
634 op.ifm_shapes[0] = new_shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200635 return op
636
637
638def convert_batched_fc_shape(op, arch, nng):
639 if op.type == Op.FullyConnected:
640 # Check if the first dimension indicates batching
641 if op.ifm_shapes[0].batch > 1:
642 batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
643 n = op.ifm_shapes[0].batch
644 h, w = batching_split.get(n, (1, n))
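            # e.g. a hypothetical batch of 8 is laid out as a 1x2x4xC feature map, mapping the
            # batch dimension onto the spatial H x W dimensions handled by the NPU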
645 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
646
647 # Reshape Weights to be 4D. IO becomes HWIO
648 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +0100649 weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
650 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200651
652 n = op.ofm_shapes[0].batch
653 h, w = batching_split.get(n, (1, n))
654 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
655 return op
656
657
658def unfuse_activation_function(op):
659 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
660 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
661 op.activation = None
662 out_tens = op.outputs[0]
663 intermediate_tens = out_tens.clone("_act_intermediate")
664 act_op.set_output_tensor(out_tens)
665 act_op.add_input_tensor(intermediate_tens)
666 op.set_output_tensor(intermediate_tens)
667 act_op.set_ifm_ofm_shapes()
668
669
670def rewrite_stridedslice_output(op, arch, nng):
671 if not op.run_on_npu or op.type != Op.StridedSlice:
672 return op
673
674 new_axis_mask = op.attrs["new_axis_mask"]
675 shrink_axis_mask = op.attrs["shrink_axis_mask"]
676
677 if shrink_axis_mask == 0 and new_axis_mask == 0:
678 return op
679
680 axis_4D = [0] * len(op.outputs)
681 for idx, out_tens in enumerate(op.outputs):
682 output_shape = list(out_tens.shape)
683
684 if shrink_axis_mask != 0:
685 n = 0
686 axis = 0
687 while shrink_axis_mask:
688 prev_mask = shrink_axis_mask
689 n += 1
690 shrink_axis_mask &= shrink_axis_mask - 1
691 axis = int(math.log2(prev_mask - shrink_axis_mask))
692 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
693
694 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
695 op.attrs["shrink_axis_mask"] = 0
696 if axis >= 0:
697 axis_4D[idx] = axis + (4 - len(output_shape))
698 else:
699 axis_4D[idx] = axis
700 op.ofm_shapes[idx] = Shape4D(output_shape)
701
702 elif new_axis_mask != 0:
703 n = 0
704 axis = 0
705 while new_axis_mask:
706 prev_mask = new_axis_mask
707 n += 1
708 new_axis_mask &= new_axis_mask - 1
709 axis = int(math.log2(prev_mask - new_axis_mask))
710 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
711 new_axis_mask >>= 1
712
713 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
714 op.attrs["new_axis_mask"] = 0
715 if axis >= 0:
716 axis_4D[idx] = axis + (4 - len(output_shape))
717 else:
718 axis_4D[idx] = axis
719 op.ofm_shapes[idx] = Shape4D(output_shape)
720
721 op.attrs["split_axis_4D"] = axis_4D
722 return op
723
724
725def rewrite_unpack_output(op, arch, nng):
726 tens = op.outputs[0]
727 if op.run_on_npu and op.type == Op.Unpack:
728 # Unpack is also referred to as Unstack
729 axis = int(op.attrs["axis"])
730 if axis < 0: # Convert to positive axis
731 axis = len(op.inputs[0].shape) + 1 + axis
732 op.type = Op.UnpackReshaped
733 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
734
735 axis_4D = axis + (4 - len(desired_output_shape))
736 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
737
738 for idx, out_tens in enumerate(op.outputs):
739 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
740 return op
741
742
743def add_padding_fields(op, arch, nng):
744 if op.run_on_npu:
745 if "padding" in op.attrs:
746 input_shape = op.ifm_shapes[0]
747 output_shape = op.ofm_shapes[0]
748 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
749 kernel_size = op.inputs[1].shape[:2]
750 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
751 kernel_size = op.attrs["ksize"][1:3]
752 else:
753 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
754
755 if op.type == Op.Conv2DBackpropInputSwitchedBias:
756 upscaling_factor = output_shape.height // input_shape.height
757 padding, skirt = calc_upscaled_padding_and_skirt(
758 op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
759 )
760 else:
761 padding, skirt = calc_padding_and_skirt(
Jonas Ohlssond8575072022-03-30 10:30:25 +0200762 op.attrs["padding"],
763 op.kernel,
764 input_shape,
765 op.attrs.get("explicit_padding"),
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200766 )
767
768 op.attrs["explicit_padding"] = padding
769 op.attrs["skirt"] = skirt
770
771 return op
772
773
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200774def reorder_depthwise_weights(op, arch, nng):
775 if op.type.is_depthwise_conv2d_op():
776 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +0100777 weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
778 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200779 weight_tensor.weight_transpose_depthwise = True
780
781 return op
782
783
784def optimise_strided_conv(op, arch, nng):
Louis Verhaard43d27582022-03-17 14:06:00 +0100785 if op.type != Op.Conv2DBias or op.op_index != 0:
786 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200787 stride_x, stride_y = op.get_kernel_stride()
Louis Verhaard43d27582022-03-17 14:06:00 +0100788 weight_tensor = op.weights
789 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200790
791 if (
Louis Verhaard43d27582022-03-17 14:06:00 +0100792 stride_x == 2
793 and ifm_shape.depth <= 4
794 and ifm_shape.width % 2 == 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200795 and weight_tensor is not None
796 and weight_tensor.shape[1] >= 2
797 ):
Louis Verhaard43d27582022-03-17 14:06:00 +0100798 k_w, _ = op.get_kernel_size()
799 curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
800 optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
801 if curr_padding_x != optimised_padding_x:
802 # Horizontal padding would become different after optimisation; this would not work
803 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200804 # IFM
805 op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
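        # e.g. a hypothetical 1xHx16x3 IFM is reinterpreted as 1xHx8x6: pairs of adjacent columns
        # are folded into the channel dimension so that the stride-2 convolution can run as a
        # stride-1 convolution (the weights are padded and reshaped to match below)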
806
807 # Weights
808 weight_shape = weight_tensor.shape
809 if weight_shape[1] % 2 != 0:
810 weight_shape[1] = weight_shape[1] + 1
811 padded_array = np.zeros(weight_shape)
812 for i in range(weight_shape[0]):
813 padded_array[i] = np.vstack(
814 [
James Peet7519d502021-07-19 16:47:58 +0100815 weight_tensor.values[i],
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200816 np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
817 ]
818 )
James Peet7519d502021-07-19 16:47:58 +0100819 weight_tensor.values = padded_array
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200820 weight_shape[1] //= 2
821 weight_shape[2] *= 2
James Peet7519d502021-07-19 16:47:58 +0100822 weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200823 weight_tensor.set_all_shapes(weight_shape)
824 # If multiple copies of the weights are used, we could avoid
825 # them having the same address by changing the value_id
826 weight_tensor.value_id = uuid.uuid4()
827
828 # Strides
829 stride_x = 1
830 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
831
832 return op
833
834
835def convert_conv_to_fc(op, arch, nng):
836 # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
840 if op.type == Op.Conv2DBias:
841 h = op.ifm_shapes[0].height
842 w = op.ifm_shapes[0].width
843 kh, kw, _, _ = op.inputs[1].shape
844 if h == 1 and w == 1 and kh == 1 and kw == 1:
845 # Overwrite this op as a Fully Connected Op
846 op.name += "_fc"
847 op.type = Op.FullyConnected
848 op.attrs = {
849 "weights_format": 0,
850 }
851 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
852 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +0100853 weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
854 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200855
856 DebugDatabase.add_optimised(op, op)
857 return op
858
859
860def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
861 if op.run_on_npu and op.type.is_relu_op():
862 ifm = op.inputs[0]
863 ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own primary op to be inserted
866 if not check_quantized_tens_scaling_equal(ifm, ofm):
867 # Override this op with its own primary op (avgpool)
868 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
869 # And fuse the original activation function to it
870 relu_fused_op.activation = create_activation_function(op.type)
Fredrik Svedberg1a7527c2021-09-13 15:52:16 +0200871 # Add explicit rescaling
872 rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
873 multiplier, shift = scaling.quantise_scale(rescale)
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +0200874 relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200875 # Tidy up and assign the ifm and ofm to the new op
876 ifm.consumer_list.remove(op)
877
878 relu_fused_op.add_input_tensor(ifm)
879 relu_fused_op.set_output_tensor(ofm)
880 relu_fused_op.set_ifm_ofm_shapes()
881 op = relu_fused_op
882 return op
883
884
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200885def convert_softmax(op, arch, nng):
886 if op.type == Op.Softmax and op.run_on_npu:
887 softmax = SoftMax(op)
888 op = softmax.get_graph()
889 return op
890
891
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +0200892def convert_prelu(op, arch, nng):
893 if op.type == Op.Prelu:
894 ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
895 if None in (ifm, alpha, ofm):
896 return op
897
Fredrik Svedberg66591652022-08-29 10:51:27 +0200898 if alpha.values is not None:
899 # If const alpha check for possible optimisations
900 alpha_zp = alpha.quantization.zero_point
901 alpha_scale = alpha.quantization.scale_f32
902 # If all alpha values are the same the PReLU can be converted to LeakyRelu
            # Note: cast via the builtin int (np.int was removed in NumPy 1.24)
            alpha_min = (alpha.values.min().astype(int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(int) - alpha_zp) * alpha_scale
905 if alpha_min == alpha_max:
906 # or even a Relu
907 if alpha_min == 0:
908 new_op = Op.Relu
909 else:
910 new_op = Op.LeakyRelu
911 op.attrs["alpha"] = alpha_min
912 # setup alpha_scaling for bit exact result
913 ifm_scale = ifm.quantization.scale_f32
914 ofm_scale = ofm.quantization.scale_f32
915 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
916 op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
917 # Change op type
918 op.type = new_op
919 op.name = op.name.replace("Prelu", new_op.name)
920 del op.inputs[1] # Remove alpha tensor
921 return op
922 elif alpha_max < 1:
923 # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
924 # Multiply with alpha tensor
925 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
926 mul_alpha.add_input_tensor(ifm)
927 mul_alpha.add_input_tensor(alpha)
928 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
929 mul_alpha.set_output_tensor(fm_alpha)
930 mul_alpha.set_ifm_ofm_shapes()
931 DebugDatabase.add_optimised(op, mul_alpha)
932 if check_quantized_tens_scaling_equal(ifm, ofm):
933 # No scaling is needed
934 fm_id = ifm
935 else:
936 # Add multiplication with identity
937 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
938 mul_identity.add_input_tensor(ifm)
939 # Create const tensor containing identity as scalar
940 quantization = ifm.quantization.clone()
941 quantization.scale_f32 = np.float32(1)
942 quantization.zero_point = 0
943 one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
944 mul_identity.add_input_tensor(one)
945 # Make sure that fm_id is allocated to a different address than fm_alpha
946 fm_id = ofm.clone(op.name + "_id", set_unique=True)
947 mul_identity.set_output_tensor(fm_id)
948 mul_identity.set_ifm_ofm_shapes()
949
950 # Combine scaled and alpha multiplied values
951 max_op = Operation(Op.Maximum, op.name + "_max")
952 max_op.add_input_tensor(fm_alpha)
953 max_op.add_input_tensor(fm_id)
954 max_op.set_output_tensor(ofm)
955 max_op.set_ifm_ofm_shapes()
956
957 DebugDatabase.add_optimised(op, max_op)
958 ifm.consumer_list.remove(op)
959 return max_op
960
961 # Catch all PReLU conversion for the cases that could not be optimised above
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +0200962 no_scale_quant = ifm.quantization.clone()
963 no_scale_quant.scale_f32 = None
964 no_scale_quant.zero_point = 0
Fredrik Svedberg66591652022-08-29 10:51:27 +0200965 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +0200966
967 # Select values < 0
968 min_op = Operation(Op.Minimum, op.name + "_min")
969 min_op.add_input_tensor(ifm)
970 min_op.add_input_tensor(zero)
971 fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
972 min_op.set_output_tensor(fm_negative)
973 min_op.set_ifm_ofm_shapes()
974 DebugDatabase.add_optimised(op, min_op)
975
976 # and multiply with alpha tensor
977 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
978 mul_alpha.add_input_tensor(fm_negative)
979 mul_alpha.add_input_tensor(alpha)
980 fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
981 mul_alpha.set_output_tensor(fm_alpha)
982 mul_alpha.set_ifm_ofm_shapes()
983 DebugDatabase.add_optimised(op, mul_alpha)
984
985 # Select (and scale) values > 0
986 relu_op = Operation(Op.Relu, op.name + "_relu")
987 relu_op.add_input_tensor(ifm)
988 fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
989 relu_op.set_output_tensor(fm_scaled)
990 relu_op.set_ifm_ofm_shapes()
991 DebugDatabase.add_optimised(op, relu_op)
992
993 # Add scaled and alpha multiplied values (without scaling)
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +0200994 add_op = Operation(Op.Add, op.name + "_add")
995 add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +0200996 add_op.add_input_tensor(fm_alpha)
997 add_op.add_input_tensor(fm_scaled)
998 add_op.set_output_tensor(ofm)
999 add_op.set_ifm_ofm_shapes()
1000
1001 DebugDatabase.add_optimised(op, add_op)
1002 ifm.consumer_list.remove(op)
1003 op = add_op
1004
1005 return op
1006
1007
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001008def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X    For X = -1 or X > 0
    |   \   /     This subgraph can be replaced with either
    |    Mul      an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """
1017
1018 if op.type == Op.Maximum:
1019 # finds the Mul input(s) to the Max
1020 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
1021 if len(muls) == 1:
1022 mul = muls[0].ops[0]
1023 elif len(muls) == 2:
1024 # In the case both inputs are Muls, find the one with the same input as the Max
Fredrik Svedberg66591652022-08-29 10:51:27 +02001025 mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
1026 if len(mul_ifms):
1027 mul = mul_ifms[0].ops[0]
1028 else:
1029 # Not using same input
1030 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001031 else:
1032 # No Mul inputs
1033 return op
1034
1035 # make sure the Mul doesn't have any other consumers
1036 mul_ofm = mul.outputs[0]
1037 if len(mul_ofm.consumers()) != 1:
1038 return op
1039 # make sure the Mul doesn't have a fused activation function
1040 if mul.activation:
1041 return op
1042 ifm, ofm = op.get_ifm_ofm()
1043 if ifm is None or ofm is None:
1044 return op
1045
1046 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1047 return op
1048 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
1049 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
1050 return op
1051
1052 # finds the branched input that goes to both the Max and the Mul
1053 shared = set(op.inputs) & set(mul.inputs)
1054 if len(shared) == 1:
1055 shared_in = shared.pop()
1056 # find the constant scalar input to the Mul
1057 const_tens = (set(mul.inputs) - {shared_in}).pop()
1058 # check that it is a scalar
1059 if const_tens.shape != []:
1060 return op
1061 const = const_tens.ops[0]
1062 # check that it is a constant
1063 if const.type != Op.Const:
1064 return op
1065 # Remove the Mul from the shared input's consumers
1066 shared_in.consumer_list.remove(mul)
1067 else:
1068 return op
1069
1070 val = const.outputs[0].values
1071 if val >= 0:
1072 new_op = Op.LeakyRelu
1073 op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scaling", to be used as input
            # to the LUT construction
James Peet7519d502021-07-19 16:47:58 +01001077 alpha_scalar = const_tens.values - const_tens.quantization.zero_point
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001078 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
1079 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
1080 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
1081 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
1082 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
1083 elif val == -1:
1084 new_op = Op.Abs
1085 else:
1086 return op
1087
1088 op.type = new_op
1089 op.name = op.name.replace("Maximum", new_op.name)
1090 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
1091 op.inputs = [shared_in]
1092 op.set_ifm_ofm_shapes()
1093
1094 # Record optimisation in debug database
1095 DebugDatabase.add_optimised(op, op)
1096
1097 return op
1098
1099
1100def convert_hardswish_to_lut(op, arch, nng):
1101 if op.type == Op.HardSwish:
1102 ifm, ofm = op.get_ifm_ofm()
1103 # Generate the LUT
1104 ifm_scale = np.double(ifm.quantization.scale_f32)
1105 ofm_scale = np.double(ofm.quantization.scale_f32)
1106 zp_in = ifm.quantization.zero_point
1107 zp_out = ofm.quantization.zero_point
1108 ifm_scale_hires = (1 / 128) * ifm_scale
1109 relu_multiplier = np.double(3 / 32768)
1110 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
1111 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
1112 # Use 16bit scale
1113 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
1114 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
1115
1116 values = []
1117 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1118 quantized_min = min(ix)
1119 quantized_max = max(ix)
1120 for x in ix:
1121 input_value = x - zp_in
1122 input_value_hires = input_value * 128
1123 # Compute the input value on essentially the output scale, not shifted yet
1124 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
1125 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
1126 relu_value = np.int16(input_value_hires)
1127 if relu_shift < 31:
1128 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
1129
1130 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
1131
1132 if relu_shift < 31:
1133 relu_value = fp_math.shift_left16(relu_value, 1)
1134
1135 if relu_shift > 31:
1136 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
1137
1138 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
1139 # Now convert that to a 16bit fixedpoint value in [0, 1]
1140 relu_value = (relu_value + (1 << 15)) >> 1
1141 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
1142 shift = 31 - out_shift
1143 shift = -shift if shift < 0 else 0
1144 # Finally apply the output shift
1145 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
1146 lut_result = min(quantized_max, max(quantized_min, lut_result))
1147 values.append(lut_result)
1148 return convert_to_lut(op, values, "hardswish")
1149 return op
1150
1151
1152def convert_lrelu_to_mul_max(op, arch):
1153 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
1154 # (the opposite of convert_mul_max_to_abs_or_lrelu)
1155 ifm, ofm = op.get_ifm_ofm()
1156 if ifm is None or ofm is None:
1157 return op
1158
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001159 alpha = np.float32(op.attrs["alpha"])
1160 use_mul_max = 0 < alpha < 1
Fredrik Svedberg36424312022-09-16 09:39:26 +02001161 is_converted_prelu = "alpha_scaling" in op.attrs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001162 if use_mul_max:
1163 mul_ifm = ifm
1164 new_op = Op.Maximum
1165 else:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001166 # Need to use a different approach for alpha < 0 or alpha > 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001167 no_scale_quant = ifm.quantization.clone()
1168 no_scale_quant.scale_f32 = None
1169 no_scale_quant.zero_point = 0
1170 zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)
1171
1172 # Select values < 0
1173 min_op = Operation(Op.Minimum, op.name + "_min")
1174 min_op.add_input_tensor(ifm)
1175 min_op.add_input_tensor(zero)
1176 mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001177 if alpha < 0 and not is_converted_prelu:
1178 # For negative alpha that is not from a converted PReLU we need to use
1179 # int32 Mul below to perform the (negative) alpha scaling
1180 mul_ifm.dtype = DataType.int32
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001181 min_op.set_output_tensor(mul_ifm)
1182 min_op.set_ifm_ofm_shapes()
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001183 new_op = Op.Add
1184 op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001185 DebugDatabase.add_optimised(op, min_op)
1186
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001187 # Add multiplication with alpha
1188 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001189 mul_alpha.add_input_tensor(mul_ifm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001190 # Create const tensor containing alpha as scalar
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001191 quantization = ifm.quantization.clone()
1192 quantization.min = 0
1193 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
1194 quantization.zero_point = 0
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001195 alpha_dtype = mul_ifm.dtype
Fredrik Svedberg36424312022-09-16 09:39:26 +02001196 if is_converted_prelu:
1197 # The LeakyRelu was the result from convert_prelu and the scaling is provided
Fredrik Svedberg66591652022-08-29 10:51:27 +02001198 scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
Fredrik Svedberg4a434cb2022-09-27 14:13:01 +02001199 mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001200 elif alpha == 0 or np.isinf(1 / alpha):
1201 # Handling of alpha near or at zero
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001202 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001203 scalar = 0
1204 else:
1205 quantization.scale_f32 = alpha
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001206 if alpha_dtype == DataType.int32:
Fredrik Svedberg36424312022-09-16 09:39:26 +02001207 # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001208 scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
1209 else:
1210 scalar = 1
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001211 alpha_tens = create_const_tensor(
Fredrik Svedberg7f3ccd52022-09-13 15:22:01 +02001212 op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001213 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001214 mul_alpha.add_input_tensor(alpha_tens)
1215 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
1216 mul_alpha.set_output_tensor(fm_alpha)
1217 mul_alpha.set_ifm_ofm_shapes()
1218 DebugDatabase.add_optimised(op, mul_alpha)
1219
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001220 if not use_mul_max:
1221 relu_op = Operation(Op.Relu, op.name + "_relu")
1222 relu_op.add_input_tensor(ifm)
1223 fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
1224 relu_op.set_output_tensor(fm_id)
1225 relu_op.set_ifm_ofm_shapes()
1226 DebugDatabase.add_optimised(op, relu_op)
1227 elif check_quantized_tens_scaling_equal(ifm, ofm):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001228 # No identity multiplication is needed
1229 fm_id = ifm
1230 else:
1231 # Add multiplication with identity
1232 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
1233 mul_identity.add_input_tensor(ifm)
1234 # Create const tensor containing identity as scalar
1235 quantization = ifm.quantization.clone()
1236 quantization.min = 0
1237 quantization.max = quantization.quant_max - quantization.quant_min
Fredrik Svedbergcce872b2021-09-02 15:20:52 +02001238 quantization.scale_f32 = np.float32(1)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001239 quantization.zero_point = 0
1240 identity_tens = create_const_tensor(
1241 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
1242 )
1243 mul_identity.add_input_tensor(identity_tens)
1244 # Make sure that fm_id is allocated to a different address than fm_alpha
1245 fm_id = ofm.clone(op.name + "_id", set_unique=True)
1246 mul_identity.set_output_tensor(fm_id)
1247 mul_identity.set_ifm_ofm_shapes()
1248 DebugDatabase.add_optimised(op, mul_identity)
1249
1250 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001251 op.type = new_op
1252 op.name = op.name.replace("LeakyRelu", new_op.name)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001253 op.inputs = []
1254 ifm.consumer_list.remove(op)
1255 op.add_input_tensor(fm_alpha)
1256 op.add_input_tensor(fm_id)
1257 op.set_ifm_ofm_shapes()
1258
1259 DebugDatabase.add_optimised(op, op)
1260 return op
1261
1262
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001263def convert_to_lut8(op, fn, fn_name):
1264 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
1265 # fn is a function(real) -> real
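    # e.g. convert_to_lut8(op, math.tanh, "tanh") - see convert_tanh_sigmoid_to_lut below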
1266 ifm, ofm = op.get_ifm_ofm()
1267 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
1268 return op
1269 # Generate the LUT
1270 ifm_scale = np.double(ifm.quantization.scale_f32)
1271 ofm_scale = np.double(ofm.quantization.scale_f32)
1272 zp_in = ifm.quantization.zero_point
1273 zp_out = ofm.quantization.zero_point
1274 values = []
1275 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1276 quantized_min = min(ix)
1277 quantized_max = max(ix)
1278 for x in ix:
1279 x_real = ifm_scale * (x - zp_in)
1280 y_real = fn(x_real)
1281 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1282 lut_result = min(quantized_max, max(quantized_min, lut_result))
1283 values.append(lut_result)
1284 return convert_to_lut(op, values, fn_name)
1285
1286
1287def convert_lrelu_to_lut(op, arch):
1288 ifm, ofm = op.get_ifm_ofm()
1289 # Generate the LUT
1290 alpha = op.attrs["alpha"]
1291 ifm_scale = np.double(ifm.quantization.scale_f32)
1292 ofm_scale = np.double(ofm.quantization.scale_f32)
1293 zp_in = ifm.quantization.zero_point
1294 zp_out = ofm.quantization.zero_point
1295 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1296 alpha_scalar = 1
1297 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1298 if "alpha_scaling" in op.attrs:
1299 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1300 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1301 values = []
1302 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1303 quantized_min = min(ix)
1304 quantized_max = max(ix)
1305 for x in ix:
1306 if x < zp_in:
1307 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1308 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1309 )
1310 else:
1311 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1312 lut_result = min(quantized_max, max(quantized_min, lut_result))
1313 values.append(lut_result)
1314 return convert_to_lut(op, values, "lrelu")
1315
1316
1317def convert_lrelu(op, arch, nng):
1318 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1319 if op.type != Op.LeakyRelu:
1320 return op
1321 ifm, ofm = op.get_ifm_ofm()
1322 if ifm is None or ofm is None:
1323 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001324 alpha = op.attrs["alpha"]
1325 if alpha == 0:
1326        # When alpha is 0 the operation can be converted to a ReLU
1327 op.type = Op.Relu
1328 op.name = op.name.replace("LeakyRelu", op.type.name)
1329 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001330 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1331 # use LUT for int8/uint8
1332 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001333 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001334 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001335 return op
1336 return convert_lrelu_to_mul_max(op, arch)
1337
1338
1339def convert_tanh_sigmoid_to_lut(op, arch, nng):
1340 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1341 if op.type == Op.Sigmoid:
1342 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1343 elif op.type == Op.Tanh:
1344 return convert_to_lut8(op, math.tanh, "tanh")
1345 return op
1346
1347
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001348def remove_memory_only_ops(op, arch):
1349 if op.run_on_npu and op.type in memory_only_ops:
1350 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001351
1352
1353def fuse_activation_function_with_prev(op, arch, nng):
1354 # if op is a no-op: attempts to move the activation function to the preceding op
1355 if not op.attrs.get("is_nop", False) or op.activation is None:
1356 return op
1357 ifm, ofm = op.get_ifm_ofm()
1358 if ifm is None or ofm is None:
1359 return op
1360 # finds the input(s) to the operation
1361 prev_op = ifm.ops[0]
1362 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
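    # Fusing is only attempted when prev_op runs on the NPU with a real block type, is the sole
    # producer of ifm, its output has no other consumers, and it has no activation of its own.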
1363 fuse = (
1364 prev_op.run_on_npu
1365 and prev_op.type.npu_block_type != NpuBlockType.Default
1366 and len(ifm.ops) == 1
1367 and len(prev_op.outputs[0].consumers()) == 1
1368 and prev_op.activation is None
1369 )
1370 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1371 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1372 # LUT currently only works correctly for elementwise ops
1373 fuse = False
1374 if not fuse:
1375 return op
1376 # Move the fused activation function + corresponding info to prev_op
1377 prev_op.activation = op.activation
1378 prev_op.forced_output_quantization = op.forced_output_quantization
1379 if op.activation_lut is not None:
1380 prev_op.set_activation_lut(op.activation_lut)
1381 # Bypass op
1382 prev_op.set_output_tensor(ofm)
1383 DebugDatabase.add_optimised(op, prev_op)
1384 return op
1385
1386
1387def _leading_pad_ok(leading_pad, stride, kernel_size):
1388 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1389 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
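    # Illustrative example: kernel_size=7, stride=2 gives max_size=3; leading pads of 0 and 2
    # (multiples of the stride) and 3 (equal to max_size) are accepted, while 1 is rejected.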
1390 max_size = kernel_size // 2
1391 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1392
1393
1394def replace_pad_by_hw_pad(op: Operation, arch, nng):
1395 """
1396 Tries to completely remove a PAD operator by using hardware padding.
1397 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1398 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1399 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1400 if both operations can be run on the NPU.
1401 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1402 """
1403 if (
1404 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001405 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001406 and op.run_on_npu
1407 and op.attrs["padding"] == Padding.VALID
1408 ):
1409 pad_op = op.ifm.ops[0]
1410 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1411 return op
1412 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1413 return op
1414 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1415 k = op.kernel
1416 k_w, k_h = k.dilated_wh()
1417
1418 # Check if the PAD operator can be replaced by hardware padding
1419 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1420 # Too much padding, it would require hardware padding to actually insert zeros
1421 return op
1422 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1423 return op
1424
1425 if op.type.is_avgpool_op():
1426 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1427 for pad, k_size in (
1428 (left, k_w),
1429 (right, k_w),
1430 (top, k_h),
1431 (bottom, k_h),
1432 ):
1433 if pad not in (0, k_size // 2):
1434 return op
1435 # Average pool is converted to depthwise, because NPU average pool + same padding
1436 # has a special implementation that is different from PAD followed by average pool with
1437 # valid padding.
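            # The replacement depthwise uses unit weights whose quantization scale is
            # 1 / (kernel width * kernel height), so it computes the same average while
            # allowing explicit hardware padding.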
1438 k_w, k_h = op.kernel.width, op.kernel.height
1439 ifm = op.ifm
1440 # Remember other inputs
1441 other_inputs = op.inputs[1:]
1442 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1443 quantization = QuantizationParameters(0.0, 255.0)
1444 quantization.scale_f32 = 1.0 / (k_w * k_h)
1445 quantization.zero_point = 0
1446 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1447 weights = np.full(shape, 1)
1448
1449 weight_tens = create_const_tensor(
1450 op.name + "_weights",
1451 shape,
1452 op.ifm.dtype,
1453 weights,
1454 np.uint8,
1455 purpose=TensorPurpose.Weights,
1456 quantization=quantization,
1457 )
James Peet7519d502021-07-19 16:47:58 +01001458 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001459 op.type = Op.DepthwiseConv2DBias
1460 op.inputs = []
1461 op.add_input_tensor(ifm)
1462 op.add_input_tensor(weight_tens)
1463 # Add bias tensor, all biases set to 0
1464 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001465 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001466 # Add other inputs
1467 op.inputs.extend(other_inputs)
1468 op.rounding_mode = NpuRoundingMode.NATURAL
1469
1470 # Bypass the PAD operator
1471 op.set_input_tensor(pad_op.ifm, 0)
1472 # Adjust the padding attributes of the convolution operator
1473 op.attrs["padding"] = Padding.EXPLICIT
1474 op.attrs["explicit_padding"] = (top, left, bottom, right)
1475 op.set_ifm_ofm_shapes()
1476 return op
1477
1478
1479def convert_pad(op: Operation, arch, nng):
1480 """
1481 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1482 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1483    This is done as a fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1484 """
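    # Illustrative example (assumed padding values): (top, left, bottom, right) = (1, 2, 0, 0)
    # creates the main copy op writing the IFM at offset (1, 2), plus a "_top" and a "_left"
    # average pool op that fill those borders with the OFM zero point.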
1485 if op.type != Op.Pad or not op.run_on_npu:
1486 return op
1487 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1488
1489 ifm = op.ifm
1490 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001491 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001492 ofm = op.ofm
1493 assert ofm is not None
1494 ofm.ops = []
1495 ofm_shape = op.ofm_shapes[0]
1496
1497 # Average pool op that copies IFM to the right place inside the OFM
1498 shp0 = Shape4D(0, 0, 0, 0)
1499 shp_top = shp0.with_height(top)
1500 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1501 avgpool_op.activation = op.activation
1502 quant = ofm.quantization
1503 pad_value = quant.zero_point
1504 # Add operations that fill the borders of the OFM
1505 if top > 0:
1506 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1507 zero_tens = create_const_tensor(
1508 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1509 )
1510 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1511 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1512 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1513 if bottom > 0:
1514 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1515 zero_tens = create_const_tensor(
1516 op.name + "_bottom",
1517 shape.as_list(),
1518 ofm.dtype,
1519 shape.elements() * [pad_value],
1520 np.uint8,
1521 quantization=quant,
1522 )
1523 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1524 create_avg_pool_for_concat(
1525 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1526 )
1527 if left > 0:
1528 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1529 zero_tens = create_const_tensor(
1530 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1531 )
1532 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1533 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1534 if right > 0:
1535 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1536 zero_tens = create_const_tensor(
1537 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1538 )
1539 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1540 create_avg_pool_for_concat(
1541 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1542 )
1543
1544 op.type = Op.ConcatTFLite
1545 return avgpool_op
1546
1547
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001548def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001549 if op.type.needs_bias() and op.bias is None:
1550 # Op has no bias, add bias tensor filled with zeros
1551 nr_biases = op.inputs[1].shape[-1]
1552 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001553 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1554 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1555 # For int16 the selected bias DataType will have an impact on the scaling
1556 # used when encoding the scales and biases later. The default mode will match the
1557        # reference with reduced scaling for int64 bias.
1558 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1559        # is used to emulate average pool, int32 bias should be selected for full precision
1560 # int16 scaling.
1561 if dtype is None:
1562 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1563 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001564 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1565
1566 return op
1567
1568
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001569def fixup_asymmetric_weights(op, arch, nng):
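    # The TFLite quantization spec expects int8 weights to be symmetric (zero point 0), so any
    # non-zero weight zero point is forced to zero here and a warning is printed.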
1570 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1571 if op.ifm.dtype == DataType.int8:
1572 if not np.all(op.weights.quantization.zero_point == 0):
1573 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1574 op.weights.quantization.zero_point *= 0
1575
1576 return op
1577
1578
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001579def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1580 if op.type == Op.Mean and op.run_on_npu:
1581 keep_dims = op.attrs.get("keep_dims", False)
1582 inp, axis = op.inputs
1583 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001584 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001585 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001586 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001587
1588        # Height and width axes have different indices depending on the number of dimensions
1589 if axis.shape == [] or axis.shape[0] == 1: # single axis
1590 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1591 if dims in (2, 3):
1592 if axis == 0:
1593 h, w = shape[axis], 1
1594 else:
1595 h, w = 1, shape[axis]
1596 else:
1597 if axis == 1:
1598 h, w = shape[axis], 1
1599 else:
1600 h, w = 1, shape[axis]
1601 else: # multiple axes
1602 axis = sorted(axis.values)
1603 h, w = [shape[i] for i in axis]
1604
1605 # Set necessary depthwise attributes
1606 op.attrs.update(
1607 {
1608 "padding": Padding.VALID,
1609 "stride_h": 1,
1610 "stride_w": 1,
1611 "strides": (1, 1, 1, 1),
1612 "depth_multiplier": 1,
1613 "channel_multiplier": 1,
1614 "dilation_h_factor": 1,
1615 "dilation_w_factor": 1,
1616 "dilation": (1, 1, 1, 1),
1617 }
1618 )
1619 # Change op type
1620 op.type = Op.DepthwiseConv2DBias
1621 # Set IFM/OFM shapes after changing op type
1622 op.set_ifm_ofm_shapes()
1623
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001624 weight_scale, bias = 1, 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001625 ofmq, ifmq = op.ofm.quantization, inp.quantization
1626 # Set rounding mode, scaling and zero point based on which reference implementation to match
1627 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1628 if inp.dtype == DataType.uint8:
1629 # This attribute means a different scaling calculation is used in order to match reference
1630 op.low_precision_scaling = True
1631 weight_scale = h * w
1632 # Set zero points to 0 as they will be adjusted for with bias term
1633 foq = ofmq.clone()
1634 foq.zero_point = 0
1635 fiq = ifmq.clone()
1636 fiq.zero_point = 0
1637 op.forced_input_quantization = fiq
Johan Alfvén17009392022-08-30 09:14:56 +02001638 bias_term = ofmq.zero_point - round_up_to_int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001639 # If the bias term is outside uint8 range, we need an Add op to apply it.
1640 if bias_term < 0 or bias_term > 255:
1641 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1642 # Bias term has higher bitness (i32) than input/output (u8).
1643                    # 16 bits is enough: since the bias is added/subtracted from a u8 value,
1644 # the bias can only effectively assume values in the range [-255, 255].
1645 intermediate.dtype = DataType.int16
1646 intermediate.quantization.zero_point = 0
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001647 add_op = Operation(Op.Add, f"{op.name}_bias")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001648 add_op.forced_output_quantization = foq
1649 add_op.add_input_tensor(intermediate)
1650 quant = QuantizationParameters()
1651 quant.zero_point = 0
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001652 bias_scalar = create_const_tensor(add_op.name, [], DataType.int16, [bias_term], quantization=quant)
1653 add_op.add_input_tensor(bias_scalar)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001654 add_op.set_output_tensor(op.ofm)
1655 add_op.set_ifm_ofm_shapes()
1656 add_op.activation = op.activation
1657 op.activation = None
1658 op.set_output_tensor(intermediate)
1659 op.set_ifm_ofm_shapes()
1660 # If not, we can just do it with the OFM zero point.
1661 else:
1662 foq.zero_point = bias_term
1663 op.forced_output_quantization = foq
1664 else:
1665 assert inp.dtype == DataType.int8
1666 # Use a depthwise to calculate the sum,
1667 # followed by a multiplication with 1/N to get the MEAN
1668 weight_scale = 1
1669 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
Johan Alfvén05916632022-09-06 20:33:22 +02001670 intermediate.dtype = DataType.int32
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001671 mul_op = Operation(Op.Mul, op.name + "_mul")
1672 mul_op.add_input_tensor(intermediate)
Johan Alfvén05916632022-09-06 20:33:22 +02001673 mul_op.set_output_tensor(op.ofm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001674 # Create scalar containing 1/N
1675 quant = QuantizationParameters()
1676 quant.zero_point = 0
1677 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1678 # while rounding mode NATURAL would round this to -1.
1679 # This can only occur if N is even, and can be emulated by
1680 # multiplying with a number that is slightly smaller than 1/N.
1681 # It must be so small that other roundings are not affected;
1682 # the calculated value is based on worst case,
1683 # which is sum 256 * N (the maximum sum that can occur with int8)
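                # Illustrative example: h * w = 4 gives eps = 1/1280 and scale_f32 of about
                # 0.25005, so a sum of -6 maps to about -1.5003, which NATURAL rounding takes
                # to -2, matching the reference's rounding of -1.5 down to -2.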
1684 n = int(h * w)
1685 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1686 quant.scale_f32 = 1 / (n - eps)
Johan Alfvén05916632022-09-06 20:33:22 +02001687
1688 # For int8/int16 we could use IFM/OFM scaling to do the division
1689                # intermediate * 1 -> scale -> round and shift.
1690                #
1691                # For int32, scaling is not supported, so instead we multiply with the scale
1692                # intermediate * scale -> round and shift.
1693                #
1694                # Calculate the scale and shift value. The const tensor must be created
1695                # with the correct quantization since the scale and shift are calculated later
1696 # in the command stream generator.
1697 mul_scale, _ = scaling.elementwise_mul_scale(
1698 mul_op.ifm.quantization.scale_f32, quant.scale_f32, mul_op.ofm.quantization.scale_f32
1699 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001700 scalar = create_const_tensor(
Johan Alfvén05916632022-09-06 20:33:22 +02001701 op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [mul_scale], np.int32, quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001702 )
1703 mul_op.add_input_tensor(scalar)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001704 mul_op.set_ifm_ofm_shapes()
1705 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1706 mul_op.activation = op.activation
1707 op.activation = None
1708 op.set_output_tensor(intermediate)
1709 op.set_ifm_ofm_shapes()
1710 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1711 # Here we can just use a simple AvgPool with truncating rounding,
1712 # as we're emulating simple integer division.
1713 op.rounding_mode = NpuRoundingMode.TRUNCATE
1714 op.type = Op.AvgPool
1715 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1716 else:
1717 op.rounding_mode = NpuRoundingMode.NATURAL
1718 weight_scale = 1 / (h * w)
1719 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1720 bias = -ifmq.zero_point * h * w
1721 fiq = ifmq.clone()
1722 fiq.zero_point = 0
1723 op.forced_input_quantization = fiq
1724
1725 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001726 def extend_dims(dim, in_shape):
1727 if dim < 4:
1728 in_shape = [1] + in_shape
1729 if dim == 2:
1730 in_shape += [1]
1731 return in_shape
1732
1733 if dims < 4 or dims_ofm < 4:
1734 # Fix the ofm dimension when keep_dims is false
1735 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1736 if isinstance(axis, int) and dims_ofm + 1 == dims:
1737 ofm_shape.insert(axis, 1)
1738 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1739 for i in axis:
1740 ofm_shape.insert(i, 1)
1741 shape = extend_dims(dims, shape)
1742 dims_ofm = len(ofm_shape)
1743 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1744 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001745
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001746 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001747 weight_shape = None
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001748 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001749            # This can only happen, and can only be done, for reductions over multiple axes, and
1750 # h * w <= 256 for DepthwiseConv2DBias
1751 # h * w <= 4096 for AvgPool
1752 # which is checked in supported ops
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001753 shape = [shape[0], 1, h * w, shape[3]]
1754 op.ifm_shapes[0] = Shape4D(shape)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001755 weight_shape = [1, h * w, shape[3], shape[0]]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001756 if h > 256 and op.type == Op.AvgPool:
1757 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1758
1759 # If the AvgPool version is used, we don't need to do anything else
1760 if op.type == Op.AvgPool:
1761 return op
1762
1763 # Make unit weight tensor quantization
1764 weight_quant = ifmq.clone()
1765 weight_quant.min = 0
1766 weight_quant.max = 255
1767 weight_quant.scale_f32 = weight_scale
1768 weight_quant.zero_point = 0
1769
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001770 if weight_shape is None:
1771 # Set weight shape to [H,W,C,B]
1772 weight_shape = [h, w, shape[3], shape[0]]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001773
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001774 # Add unit weight tensor
1775 op.set_input_tensor(
1776 create_const_tensor(
1777 "weights",
1778 weight_shape,
1779 inp.dtype,
1780 np.ones(weight_shape),
1781 value_dtype=np.uint8,
1782 quantization=weight_quant,
1783 ),
1784 1,
1785 )
James Peet7519d502021-07-19 16:47:58 +01001786 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001787
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001788 # Add bias tensor
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001789 bias_shape = [shape[-1]]
1790 op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001791
1792 return op
1793
1794
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001795def optimise_quantize(op: Operation, arch, nng):
1796
1797 if op.type == Op.Quantize and op.run_on_npu:
1798
1799 ifm, ofm = op.get_ifm_ofm()
1800 input_values = ifm.values
1801
1802 # Guard clause - input not const or no values to quantize
1803 if ifm.ops[0].type != Op.Const or input_values is None:
1804 return op
1805
1806 # Singular val in numpy array, convert to indexable array
1807 if input_values.ndim == 0:
1808 input_values = np.array([input_values])
1809
Fredrik Svedberg11563172022-07-06 14:54:12 +02001810        # Case: requantize int8 to int8 or int16 to int16
1811 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001812
1813 # scale needs to use double precision to match TFLite reference kernel
1814 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1815 effective_multiplier, effective_shift = quantise_scale(effective_scale)
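            # Each constant value is then requantized as
            #   ofm_val = multiply_by_quantized_multiplier(val - zp_ifm, multiplier, shift) + zp_ofm
            # (zp_ifm/zp_ofm being the ifm/ofm zero points) and clamped to the OFM quantization
            # range, mirroring the TFLite reference kernel.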
1816
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001817 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001818 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001819 input_val = val - ifm.quantization.zero_point
1820
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001821 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1822 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001823
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001824 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1825 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001826
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001827 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1828 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001829
1830 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001831 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001832
1833 quantized_vals = []
1834 for val in input_values:
1835
1836 # Derive quantized value
1837 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001838 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1839 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001840
1841 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001842 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1843
1844 # Unsupported data type
1845 else:
1846 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001847
1848 # Make quantize op const and disconnect from parent node
1849
1850 # Remove reference of the current quant op from the parent tensor's consumer list
1851 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1852
1853 # Clear any references to parent node
1854 op.inputs = []
1855
1856 # Convert this quantize op to const
1857 op.type = Op.Const
1858
1859 return op
1860
1861
Ayaan Masood4965fae2022-06-29 11:30:57 +01001862def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1863    """Static optimisation for the SHAPE operator when its output value is known at compile time"""
1864
1865 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
1866
1867 if op.type == Op.Shape and op.run_on_npu:
1868
1869 ifm, ofm = op.get_ifm_ofm()
1870
1871 if len(ifm.shape) != ofm.shape[0]:
1872 return op
1873
1874 # Remove reference of the current shape op from the parent tensor's consumer list
1875 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1876
1877 # Clear any references to parent node
1878 op.inputs = []
1879
1880 # Convert this SHAPE op to const
1881 op.type = Op.Const
1882
1883 # Add size calculation to shape output tensors
1884 ofm.values = np.array(ifm.shape)
1885
1886 return op
1887
1888
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001889def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001890 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001891 return op
1892
1893
1894def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001895 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001896 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1897
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001898 for idx, sg in enumerate(nng.subgraphs):
1899 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001900 nng,
1901 sg,
1902 arch,
1903 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001904 optimisation_list,
1905 rewrite_unsupported=False,
1906 )
1907
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001908 # Pre-processing step
1909 pre_process_list = [
1910 supported_operator_check,
1911 set_ifm_ofm_op_shapes,
1912 ]
1913
Ayaan Masood4965fae2022-06-29 11:30:57 +01001914 for idx, sg in enumerate(nng.subgraphs):
1915 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1916 nng,
1917 sg,
1918 arch,
1919 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001920 pre_process_list,
1921 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001922 )
1923
1924 # Handle Concat Ops
1925 for idx, sg in enumerate(nng.subgraphs):
1926 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1927 sg.refresh_after_modification()
1928
1929 # Handle Split Ops
1930 for idx, sg in enumerate(nng.subgraphs):
1931 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1932 nng,
1933 sg,
1934 arch,
1935 [],
1936 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1937 rewrite_unsupported=False,
1938 )
1939
1940 for idx, sg in enumerate(nng.subgraphs):
1941 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001942 nng,
1943 sg,
1944 arch,
1945 [rewrite_split_ops],
1946 [],
1947 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001948 )
1949
1950 # Handle sg input output
1951 for idx, sg in enumerate(nng.subgraphs):
1952 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001953 nng,
1954 sg,
1955 arch,
1956 [],
1957 [fix_sg_input_output],
1958 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001959 )
1960
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001961 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001962 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001963 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001964 sg.refresh_after_modification()
1965
1966 # Rewrite of operators
1967 op_rewrite_list = [
1968 set_tensor_equivalence,
1969 convert_mean_to_depthwise_conv_or_avgpool,
1970 convert_depthwise_to_conv,
1971 convert_conv_to_fc,
1972 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001973 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001974 convert_mul_max_to_abs_or_lrelu,
1975 convert_lrelu,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001976 optimise_strided_conv,
1977 convert_hardswish_to_lut,
1978 rewrite_fully_connected_input,
1979 convert_batched_fc_shape,
1980 fixup_conv2d_backprop,
1981 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001982 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001983 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001984 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001985 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001986 convert_tanh_sigmoid_to_lut,
1987 replace_pad_by_hw_pad,
1988 ]
1989
1990 for idx, sg in enumerate(nng.subgraphs):
1991 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001992 nng,
1993 sg,
1994 arch,
1995 [],
1996 op_rewrite_list,
1997 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001998 )
1999
2000 for idx, sg in enumerate(nng.subgraphs):
2001 # remove passthrough tensors and attempt further optimizations
2002 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
2003 nng,
2004 sg,
2005 arch,
2006 [remove_passthrough_tensor],
2007 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
2008 )
2009
2010 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
2011 # since ifm/ofm_shapes are of importance to this function
2012 for sg in nng.subgraphs:
2013 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
2014 sg.refresh_after_modification()
2015
2016 return nng