# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op

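# Example (illustrative, assuming example shapes): concatenating two [1, 8, 8, 16] tensors along
# the depth axis produces two ConcatSliceWrite average pools writing into a [1, 8, 8, 32] OFM,
# the first with write_offset (0, 0, 0, 0) and the second with write_offset (0, 0, 0, 16).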

def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens

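# Example (illustrative, assuming example shapes): a Split of a [1, 8, 8, 16] tensor into two
# halves along the depth axis gives each output a SplitSliceRead with read_shape [1, 8, 8, 8];
# the first read starts at offset [0, 0, 0, 0] and the second at [0, 0, 0, 8].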

def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt

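# Example (illustrative, assuming example shapes): for Padding.SAME with a 3x3 kernel, stride 1
# and a 224x224 input, needed_total_padding() returns xpad = ypad = 2, so the padding becomes
# (top, left, bottom, right) = (1, 1, 1, 1) and the skirt is also (1, 1, 1, 1).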

def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to a depthwise convolution
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype == DataType.uint8:
        weight_value_dtype = np.uint8
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        if ofm_dtype == DataType.int8:
            weight_value_dtype = np.int8
        else:
            assert ofm_dtype == DataType.int16
            weight_value_dtype = np.int16

        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm.dtype,
            np.array(weight_values).reshape(weight_shape),
            value_dtype=weight_value_dtype,
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()

    return op

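# Example (illustrative): for upscale_factor = 2 the weight kernel holds four coefficients and
# centre_coeff = (2 // 2) * 2 + (2 // 2) = 3, so weight_values becomes [0, 0, 0, 1], selecting
# only the below-and-right position ('D' in the diagram above).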

# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # Keep 1x1 kernel and average pool, this applies both when
            # half-pixel-centers is True and False. Calculations are the
            # same in the reference.
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()

    return op

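# Example (illustrative): an x8 resize gives n = int(log2(8)) = 3, so the loop above emits two
# nearest-neighbour x2 pre-stages before the final stage, while an x2 resize (n = 1) skips the
# loop and is handled by the final average pool alone.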
468
Rickard Bolinfea15162022-07-04 16:19:16 +0000469def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
470 def _compute_interpolation_values(index, input_size, output_size):
471 scale = input_size / output_size
472 scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
473 lower_bound = max(np.floor(scaled_value), 0)
474
475 return scaled_value, lower_bound
476
477 def _compute_kernels(input_height, input_width, output_height, output_width):
478 kernels = []
479 for y in (1, 2):
480 for x in (1, 2):
481 sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
482 sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)
483
484 # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
485 # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
486 # top-to-bottom - same as the depthwise convolution strides across each tile
487 kernel = np.zeros((2, 2))
488 kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
489 kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
490 kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
491 kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
492 kernel *= 16
493 kernels.append(kernel)
494
495 return kernels
496
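    # Example (illustrative) for the kernels computed above: with a 2x upscale and
    # half_pixel_centers=True the fractional offsets are 0.25 and 0.75, so the first (top-left)
    # kernel evaluates to [[1, 3], [3, 9]]; each kernel sums to 16, which the 1/16 weight scale
    # used below cancels out.
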
    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        value_dtype=np.int8,
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape
    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op

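# Example (illustrative, assuming example shapes): a FullyConnected with batch 8 and depth 64 is
# reshaped to Shape4D [1, 2, 4, 64] via the batching_split table, while a batch not in the table
# (e.g. 3) falls back to [1, 1, 3, 64]; the 2D weights are expanded to HWIO [1, 1, I, O].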

def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op

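# Example (illustrative): with shrink_axis_mask = 0b0100, one loop iteration clears the mask,
# prev_mask - shrink_axis_mask = 4, so axis = log2(4) = 2 and a 1 is re-inserted at index 2 of
# output_shape, restoring the dimension that the StridedSlice squeezed out.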

def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op

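# Example (illustrative, assuming example shapes): a stride-2 Conv2D reading a [1, 16, 16, 4] IFM
# is rewritten to read it as [1, 16, 8, 8] (half the width, twice the depth) with stride_w reduced
# to 1; the kernel is padded to an even width and reshaped the same way so the arithmetic is unchanged.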

def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(np.int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(np.int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.Add, op.name + "_add")
        add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |   \   /           This subgraph can be replaced with either
    |    Mul            an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescale the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    alpha = np.float32(op.attrs["alpha"])
    use_mul_max = 0 < alpha < 1
    is_converted_prelu = "alpha_scaling" in op.attrs
    if use_mul_max:
        mul_ifm = ifm
        new_op = Op.Maximum
    else:
        # Need to use a different approach for alpha < 0 or alpha > 1
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
        if alpha < 0 and not is_converted_prelu:
            # For negative alpha that is not from a converted PReLU we need to use
            # int32 Mul below to perform the (negative) alpha scaling
            mul_ifm.dtype = DataType.int32
        min_op.set_output_tensor(mul_ifm)
        min_op.set_ifm_ofm_shapes()
        new_op = Op.Add
        op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # No scaling
        DebugDatabase.add_optimised(op, min_op)

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(mul_ifm)
    # Create const tensor containing alpha as scalar
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    alpha_dtype = mul_ifm.dtype
    if is_converted_prelu:
        # The LeakyRelu was the result from convert_prelu and the scaling is provided
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
    elif alpha == 0 or np.isinf(1 / alpha):
        # Handling of alpha near or at zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        if alpha_dtype == DataType.int32:
            # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
            scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
        else:
            scalar = 1
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
    )
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if not use_mul_max:
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_id)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)
    elif check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = new_op
    op.name = op.name.replace("LeakyRelu", new_op.name)
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)

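# Example (illustrative, assuming example quantisation): for an int8 Sigmoid with ifm_scale 0.1,
# zp_in 0, ofm_scale 1/256 and zp_out -128, the LUT entry for x = 0 is
# round_away_zero(-128 + sigmoid(0.0) * 256) = 0, clamped to the [-128, 127] range.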
1288
1289def convert_lrelu_to_lut(op, arch):
1290 ifm, ofm = op.get_ifm_ofm()
1291 # Generate the LUT
1292 alpha = op.attrs["alpha"]
1293 ifm_scale = np.double(ifm.quantization.scale_f32)
1294 ofm_scale = np.double(ofm.quantization.scale_f32)
1295 zp_in = ifm.quantization.zero_point
1296 zp_out = ofm.quantization.zero_point
1297 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1298 alpha_scalar = 1
1299 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1300 if "alpha_scaling" in op.attrs:
1301 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1302 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1303 values = []
1304 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1305 quantized_min = min(ix)
1306 quantized_max = max(ix)
1307 for x in ix:
1308 if x < zp_in:
1309 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1310 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1311 )
1312 else:
1313 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1314 lut_result = min(quantized_max, max(quantized_min, lut_result))
1315 values.append(lut_result)
1316 return convert_to_lut(op, values, "lrelu")
1317
1318
1319def convert_lrelu(op, arch, nng):
1320 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1321 if op.type != Op.LeakyRelu:
1322 return op
1323 ifm, ofm = op.get_ifm_ofm()
1324 if ifm is None or ofm is None:
1325 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001326 alpha = op.attrs["alpha"]
1327 if alpha == 0:
1328 # When alpha is 0 the operation can be converted to a ReLU
1329 op.type = Op.Relu
1330 op.name = op.name.replace("LeakyRelu", op.type.name)
1331 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001332 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1333 # use LUT for int8/uint8
1334 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001335 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001336 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001337 return op
1338 return convert_lrelu_to_mul_max(op, arch)
1339
1340
1341def convert_tanh_sigmoid_to_lut(op, arch, nng):
1342 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1343 if op.type == Op.Sigmoid:
1344 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1345 elif op.type == Op.Tanh:
1346 return convert_to_lut8(op, math.tanh, "tanh")
1347 return op
1348
1349
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001350def remove_memory_only_ops(op, arch):
1351 if op.run_on_npu and op.type in memory_only_ops:
1352 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001353
1354
1355def fuse_activation_function_with_prev(op, arch, nng):
1356 # If op is a no-op: attempt to move the activation function to the preceding op
1357 if not op.attrs.get("is_nop", False) or op.activation is None:
1358 return op
1359 ifm, ofm = op.get_ifm_ofm()
1360 if ifm is None or ofm is None:
1361 return op
1362 # find the operation that produces the ifm, i.e. the op to fuse into
1363 prev_op = ifm.ops[0]
1364 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1365 fuse = (
1366 prev_op.run_on_npu
1367 and prev_op.type.npu_block_type != NpuBlockType.Default
1368 and len(ifm.ops) == 1
1369 and len(prev_op.outputs[0].consumers()) == 1
1370 and prev_op.activation is None
1371 )
1372 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1373 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1374 # the LUT currently only works correctly for elementwise ops, so fusing is disabled
1375 fuse = False
1376 if not fuse:
1377 return op
1378 # Move the fused activation function + corresponding info to prev_op
1379 prev_op.activation = op.activation
1380 prev_op.forced_output_quantization = op.forced_output_quantization
1381 if op.activation_lut is not None:
1382 prev_op.set_activation_lut(op.activation_lut)
1383 # Bypass op
1384 prev_op.set_output_tensor(ofm)
1385 DebugDatabase.add_optimised(op, prev_op)
1386 return op
1387
1388
1389def _leading_pad_ok(leading_pad, stride, kernel_size):
1390 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1391 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
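# For example, kernel_size 7 (max_size 3) with stride 2: a leading pad of 3 or 2 is accepted,
# while a leading pad of 1 is rejected because it is neither the maximum padding nor a
# multiple of the stride.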
1392 max_size = kernel_size // 2
1393 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1394
1395
1396def replace_pad_by_hw_pad(op: Operation, arch, nng):
1397 """
1398 Tries to completely remove a PAD operator by using hardware padding.
1399 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1400 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1401 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1402 if both operations can be run on the NPU.
1403 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1404 """
1405 if (
1406 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001407 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001408 and op.run_on_npu
1409 and op.attrs["padding"] == Padding.VALID
1410 ):
1411 pad_op = op.ifm.ops[0]
1412 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1413 return op
1414 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1415 return op
1416 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1417 k = op.kernel
1418 k_w, k_h = k.dilated_wh()
1419
1420 # Check if the PAD operator can be replaced by hardware padding
1421 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1422 # Too much padding, it would require hardware padding to actually insert zeros
1423 return op
1424 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1425 return op
1426
1427 if op.type.is_avgpool_op():
1428 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1429 for pad, k_size in (
1430 (left, k_w),
1431 (right, k_w),
1432 (top, k_h),
1433 (bottom, k_h),
1434 ):
1435 if pad not in (0, k_size // 2):
1436 return op
1437 # Average pool is converted to depthwise, because NPU average pool + same padding
1438 # has a special implementation that is different from PAD followed by average pool with
1439 # valid padding.
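# The replacement below is a depthwise conv with a k_h x k_w kernel whose weights all
# represent 1 / (k_h * k_w), a zero bias and NATURAL rounding, which together reproduce
# plain averaging over the explicitly padded window.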
1440 k_w, k_h = op.kernel.width, op.kernel.height
1441 ifm = op.ifm
1442 # Remember other inputs
1443 other_inputs = op.inputs[1:]
1444 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1445 quantization = QuantizationParameters(0.0, 255.0)
1446 quantization.scale_f32 = 1.0 / (k_w * k_h)
1447 quantization.zero_point = 0
1448 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1449 weights = np.full(shape, 1)
1450
1451 weight_tens = create_const_tensor(
1452 op.name + "_weights",
1453 shape,
1454 op.ifm.dtype,
1455 weights,
1456 np.uint8,
1457 purpose=TensorPurpose.Weights,
1458 quantization=quantization,
1459 )
James Peet7519d502021-07-19 16:47:58 +01001460 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001461 op.type = Op.DepthwiseConv2DBias
1462 op.inputs = []
1463 op.add_input_tensor(ifm)
1464 op.add_input_tensor(weight_tens)
1465 # Add bias tensor, all biases set to 0
1466 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001467 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001468 # Add other inputs
1469 op.inputs.extend(other_inputs)
1470 op.rounding_mode = NpuRoundingMode.NATURAL
1471
1472 # Bypass the PAD operator
1473 op.set_input_tensor(pad_op.ifm, 0)
1474 # Adjust the padding attributes of the convolution operator
1475 op.attrs["padding"] = Padding.EXPLICIT
1476 op.attrs["explicit_padding"] = (top, left, bottom, right)
1477 op.set_ifm_ofm_shapes()
1478 return op
1479
1480
1481def convert_pad(op: Operation, arch, nng):
1482 """
1483 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1484 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1485 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1486 """
1487 if op.type != Op.Pad or not op.run_on_npu:
1488 return op
1489 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1490
1491 ifm = op.ifm
1492 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001493 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001494 ofm = op.ofm
1495 assert ofm is not None
1496 ofm.ops = []
1497 ofm_shape = op.ofm_shapes[0]
1498
1499 # Average pool op that copies IFM to the right place inside the OFM
1500 shp0 = Shape4D(0, 0, 0, 0)
1501 shp_top = shp0.with_height(top)
1502 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1503 avgpool_op.activation = op.activation
1504 quant = ofm.quantization
1505 pad_value = quant.zero_point
1506 # Add operations that fill the borders of the OFM
1507 if top > 0:
1508 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1509 zero_tens = create_const_tensor(
1510 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1511 )
1512 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1513 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1514 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1515 if bottom > 0:
1516 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1517 zero_tens = create_const_tensor(
1518 op.name + "_bottom",
1519 shape.as_list(),
1520 ofm.dtype,
1521 shape.elements() * [pad_value],
1522 np.uint8,
1523 quantization=quant,
1524 )
1525 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1526 create_avg_pool_for_concat(
1527 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1528 )
1529 if left > 0:
1530 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1531 zero_tens = create_const_tensor(
1532 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1533 )
1534 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1535 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1536 if right > 0:
1537 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1538 zero_tens = create_const_tensor(
1539 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1540 )
1541 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1542 create_avg_pool_for_concat(
1543 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1544 )
1545
1546 op.type = Op.ConcatTFLite
1547 return avgpool_op
1548
1549
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001550def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001551 if op.type.needs_bias() and op.bias is None:
1552 # Op has no bias, add bias tensor filled with zeros
1553 nr_biases = op.inputs[1].shape[-1]
1554 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001555 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1556 # DataType. The default is int32 bias for 8-bit ifms and int64 bias for int16 ifms.
1557 # For int16, the selected bias DataType affects the scaling used when encoding the
1558 # scales and biases later. The default mode matches the reference, with reduced
1559 # scaling for int64 bias.
1560 # This means that where the graph optimiser uses DepthwiseConv2DBias to emulate
1561 # average pool, int32 bias should be selected to get full precision
1562 # int16 scaling.
1563 if dtype is None:
1564 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1565 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001566 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1567
1568 return op
1569
1570
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001571def fixup_asymmetric_weights(op, arch, nng):
1572 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1573 if op.ifm.dtype == DataType.int8:
1574 if not np.all(op.weights.quantization.zero_point == 0):
1575 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1576 op.weights.quantization.zero_point *= 0
1577
1578 return op
1579
1580
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001581def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1582 if op.type == Op.Mean and op.run_on_npu:
1583 keep_dims = op.attrs.get("keep_dims", False)
1584 inp, axis = op.inputs
1585 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001586 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001587 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001588 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001589
1590 # Height and width axes have different index depending on dimensions
1591 if axis.shape == [] or axis.shape[0] == 1: # single axis
1592 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1593 if dims in (2, 3):
1594 if axis == 0:
1595 h, w = shape[axis], 1
1596 else:
1597 h, w = 1, shape[axis]
1598 else:
1599 if axis == 1:
1600 h, w = shape[axis], 1
1601 else:
1602 h, w = 1, shape[axis]
1603 else: # multiple axes
1604 axis = sorted(axis.values)
1605 h, w = [shape[i] for i in axis]
1606
1607 # Set necessary depthwise attributes
1608 op.attrs.update(
1609 {
1610 "padding": Padding.VALID,
1611 "stride_h": 1,
1612 "stride_w": 1,
1613 "strides": (1, 1, 1, 1),
1614 "depth_multiplier": 1,
1615 "channel_multiplier": 1,
1616 "dilation_h_factor": 1,
1617 "dilation_w_factor": 1,
1618 "dilation": (1, 1, 1, 1),
1619 }
1620 )
1621 # Change op type
1622 op.type = Op.DepthwiseConv2DBias
1623 # Set IFM/OFM shapes after changing op type
1624 op.set_ifm_ofm_shapes()
1625
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001626 weight_scale, bias = 1, 0
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001627 ofmq, ifmq = op.ofm.quantization, inp.quantization
1628 # Set rounding mode, scaling and zero point based on which reference implementation to match
1629 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1630 if inp.dtype == DataType.uint8:
1631 # This attribute means a different scaling calculation is used in order to match reference
1632 op.low_precision_scaling = True
1633 weight_scale = h * w
1634 # Set zero points to 0 as they will be adjusted for with bias term
1635 foq = ofmq.clone()
1636 foq.zero_point = 0
1637 fiq = ifmq.clone()
1638 fiq.zero_point = 0
1639 op.forced_input_quantization = fiq
Johan Alfvén17009392022-08-30 09:14:56 +02001640 bias_term = ofmq.zero_point - round_up_to_int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001641 # If the bias term is outside uint8 range, we need an Add op to apply it.
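# For instance, with equal IFM/OFM scales: ifm zero point 128 and ofm zero point 0 give
# bias_term = -128, which is outside [0, 255] and needs the Add op below, while ifm zero
# point 0 and ofm zero point 128 give bias_term = 128, which can be folded into the OFM
# zero point instead.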
1642 if bias_term < 0 or bias_term > 255:
1643 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1644 # Bias term has higher bitness (i32) than input/output (u8).
1645 # 16 bits is enough: since the bias is added to/subtracted from a u8 value,
1646 # it can only effectively assume values in the range [-255, 255].
1647 intermediate.dtype = DataType.int16
1648 intermediate.quantization.zero_point = 0
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001649 add_op = Operation(Op.Add, f"{op.name}_bias")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001650 add_op.forced_output_quantization = foq
1651 add_op.add_input_tensor(intermediate)
1652 quant = QuantizationParameters()
1653 quant.zero_point = 0
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001654 bias_scalar = create_const_tensor(add_op.name, [], DataType.int16, [bias_term], quantization=quant)
1655 add_op.add_input_tensor(bias_scalar)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001656 add_op.set_output_tensor(op.ofm)
1657 add_op.set_ifm_ofm_shapes()
1658 add_op.activation = op.activation
1659 op.activation = None
1660 op.set_output_tensor(intermediate)
1661 op.set_ifm_ofm_shapes()
1662 # If not, we can just do it with the OFM zero point.
1663 else:
1664 foq.zero_point = bias_term
1665 op.forced_output_quantization = foq
1666 else:
1667 assert inp.dtype == DataType.int8
1668 # Use a depthwise to calculate the sum,
1669 # followed by a multiplication with 1/N to get the MEAN
1670 weight_scale = 1
1671 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
Johan Alfvén05916632022-09-06 20:33:22 +02001672 intermediate.dtype = DataType.int32
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001673 mul_op = Operation(Op.Mul, op.name + "_mul")
1674 mul_op.add_input_tensor(intermediate)
Johan Alfvén05916632022-09-06 20:33:22 +02001675 mul_op.set_output_tensor(op.ofm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001676 # Create scalar containing 1/N
1677 quant = QuantizationParameters()
1678 quant.zero_point = 0
1679 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1680 # while rounding mode NATURAL would round this to -1.
1681 # This can only occur if N is even, and can be emulated by
1682 # multiplying with a number that is slightly smaller than 1/N.
1683 # It must be so small that other roundings are not affected;
1684 # the calculated value is based on worst case,
1685 # which is sum 256 * N (the maximum sum that can occur with int8)
1686 n = int(h * w)
1687 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1688 quant.scale_f32 = 1 / (n - eps)
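# For example, N = 4: a window sum of -6 gives -6 / 4 = -1.5, which the reference rounds
# to -2; the adjusted scale nudges the product just past the .5 boundary so that NATURAL
# rounding also produces -2, without affecting any other result.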
Johan Alfvén05916632022-09-06 20:33:22 +02001689
1690 # For int8/int16 we could use IFM/OFM scaling to do the division
1691 # intermediate * 1 -> scale > round and shift.
1692 #
1693 # For int32 scaling is not supported so instead multiply with the scale
1694 # intermediate * scale -> round and shift.
1695 #
1696 # Calculate the scale and shift value. const Tensor must be created
1697 # with correct quantization since the scale and shift is calculated later
1698 # in the command stream generator.
1699 mul_scale, _ = scaling.elementwise_mul_scale(
1700 mul_op.ifm.quantization.scale_f32, quant.scale_f32, mul_op.ofm.quantization.scale_f32
1701 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001702 scalar = create_const_tensor(
Johan Alfvén05916632022-09-06 20:33:22 +02001703 op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [mul_scale], np.int32, quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001704 )
1705 mul_op.add_input_tensor(scalar)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001706 mul_op.set_ifm_ofm_shapes()
1707 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1708 mul_op.activation = op.activation
1709 op.activation = None
1710 op.set_output_tensor(intermediate)
1711 op.set_ifm_ofm_shapes()
1712 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1713 # Here we can just use a simple AvgPool with truncating rounding,
1714 # as we're emulating simple integer division.
1715 op.rounding_mode = NpuRoundingMode.TRUNCATE
1716 op.type = Op.AvgPool
1717 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1718 else:
1719 op.rounding_mode = NpuRoundingMode.NATURAL
1720 weight_scale = 1 / (h * w)
1721 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1722 bias = -ifmq.zero_point * h * w
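# The depthwise below sums raw input values (the input zero point is forced to 0 just after
# this), so a bias of -zp_ifm * h * w removes the zero point contribution from each window
# sum before the 1/(h*w) weight scale turns the sum into a mean.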
1723 fiq = ifmq.clone()
1724 fiq.zero_point = 0
1725 op.forced_input_quantization = fiq
1726
1727 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001728 def extend_dims(dim, in_shape):
1729 if dim < 4:
1730 in_shape = [1] + in_shape
1731 if dim == 2:
1732 in_shape += [1]
1733 return in_shape
1734
1735 if dims < 4 or dims_ofm < 4:
1736 # Fix the ofm dimension when keep_dims is false
1737 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1738 if isinstance(axis, int) and dims_ofm + 1 == dims:
1739 ofm_shape.insert(axis, 1)
1740 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1741 for i in axis:
1742 ofm_shape.insert(i, 1)
1743 shape = extend_dims(dims, shape)
1744 dims_ofm = len(ofm_shape)
1745 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1746 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001747
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001748 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001749 weight_shape = None
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001750 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001751 # This can only happen, and is only done, when reducing over multiple axes; the
1752 # supported operators check guarantees that
1753 # h * w <= 256 for DepthwiseConv2DBias
1754 # h * w <= 4096 for AvgPool
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001755 shape = [shape[0], 1, h * w, shape[3]]
1756 op.ifm_shapes[0] = Shape4D(shape)
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001757 weight_shape = [1, h * w, shape[3], shape[0]]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001758 if h > 256 and op.type == Op.AvgPool:
1759 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1760
1761 # If the AvgPool version is used, we don't need to do anything else
1762 if op.type == Op.AvgPool:
1763 return op
1764
1765 # Make unit weight tensor quantization
1766 weight_quant = ifmq.clone()
1767 weight_quant.min = 0
1768 weight_quant.max = 255
1769 weight_quant.scale_f32 = weight_scale
1770 weight_quant.zero_point = 0
1771
Johan Alfvéne84ed6b2022-09-26 13:46:51 +02001772 if weight_shape is None:
1773 # Set weight shape to [H,W,C,B]
1774 weight_shape = [h, w, shape[3], shape[0]]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001775
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001776 # Add unit weight tensor
1777 op.set_input_tensor(
1778 create_const_tensor(
1779 "weights",
1780 weight_shape,
1781 inp.dtype,
1782 np.ones(weight_shape),
1783 value_dtype=np.uint8,
1784 quantization=weight_quant,
1785 ),
1786 1,
1787 )
James Peet7519d502021-07-19 16:47:58 +01001788 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001789
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001790 # Add bias tensor
Fredrik Svedberg1e5456f2022-09-23 15:25:17 +02001791 bias_shape = [shape[-1]]
1792 op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001793
1794 return op
1795
1796
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001797def optimise_quantize(op: Operation, arch, nng):
1798
1799 if op.type == Op.Quantize and op.run_on_npu:
1800
1801 ifm, ofm = op.get_ifm_ofm()
1802 input_values = ifm.values
1803
1804 # Guard clause - input not const or no values to quantize
1805 if ifm.ops[0].type != Op.Const or input_values is None:
1806 return op
1807
1808 # Singular val in numpy array, convert to indexable array
1809 if input_values.ndim == 0:
1810 input_values = np.array([input_values])
1811
Fredrik Svedberg11563172022-07-06 14:54:12 +02001812 # Requantize int8 to int8 or int16 to int16
1813 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001814
1815 # scale needs to use double precision to match TFLite reference kernel
1816 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1817 effective_multiplier, effective_shift = quantise_scale(effective_scale)
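# i.e. each output value is roughly round((val - ifm zero point) * ifm_scale / ofm_scale)
# plus the ofm zero point, clamped to the quantized output range.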
1818
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001819 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001820 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001821 input_val = val - ifm.quantization.zero_point
1822
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001823 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1824 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001825
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001826 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1827 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001828
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001829 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1830 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001831
1832 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001833 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001834
1835 quantized_vals = []
1836 for val in input_values:
1837
1838 # Derive quantized value
1839 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001840 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1841 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001842
1843 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001844 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1845
1846 # Unsupported data type
1847 else:
1848 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001849
1850 # Make quantize op const and disconnect from parent node
1851
1852 # Remove reference of the current quant op from the parent tensor's consumer list
1853 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1854
1855 # Clear any references to parent node
1856 op.inputs = []
1857
1858 # Convert this quantize op to const
1859 op.type = Op.Const
1860
1861 return op
1862
1863
Ayaan Masood4965fae2022-06-29 11:30:57 +01001864def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1865 """Static optimisation for SHAPE operator output value known at compile time"""
1866
1867 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
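# For example, an IFM of shape [1, 224, 224, 3] turns the SHAPE op into a constant
# with values [1, 224, 224, 3].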
1868
1869 if op.type == Op.Shape and op.run_on_npu:
1870
1871 ifm, ofm = op.get_ifm_ofm()
1872
1873 if len(ifm.shape) != ofm.shape[0]:
1874 return op
1875
1876 # Remove reference of the current shape op from the parent tensor's consumer list
1877 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1878
1879 # Clear any references to parent node
1880 op.inputs = []
1881
1882 # Convert this SHAPE op to const
1883 op.type = Op.Const
1884
1885 # Add size calculation to shape output tensors
1886 ofm.values = np.array(ifm.shape)
1887
1888 return op
1889
1890
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001891def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001892 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001893 return op
1894
1895
1896def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001897 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001898 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1899
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001900 for idx, sg in enumerate(nng.subgraphs):
1901 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001902 nng,
1903 sg,
1904 arch,
1905 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001906 optimisation_list,
1907 rewrite_unsupported=False,
1908 )
1909
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001910 # Pre-processing step
1911 pre_process_list = [
1912 supported_operator_check,
1913 set_ifm_ofm_op_shapes,
1914 ]
1915
Ayaan Masood4965fae2022-06-29 11:30:57 +01001916 for idx, sg in enumerate(nng.subgraphs):
1917 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1918 nng,
1919 sg,
1920 arch,
1921 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001922 pre_process_list,
1923 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001924 )
1925
1926 # Handle Concat Ops
1927 for idx, sg in enumerate(nng.subgraphs):
1928 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1929 sg.refresh_after_modification()
1930
1931 # Handle Split Ops
1932 for idx, sg in enumerate(nng.subgraphs):
1933 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1934 nng,
1935 sg,
1936 arch,
1937 [],
1938 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1939 rewrite_unsupported=False,
1940 )
1941
1942 for idx, sg in enumerate(nng.subgraphs):
1943 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001944 nng,
1945 sg,
1946 arch,
1947 [rewrite_split_ops],
1948 [],
1949 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001950 )
1951
1952 # Handle sg input output
1953 for idx, sg in enumerate(nng.subgraphs):
1954 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001955 nng,
1956 sg,
1957 arch,
1958 [],
1959 [fix_sg_input_output],
1960 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001961 )
1962
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001963 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001964 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001965 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001966 sg.refresh_after_modification()
1967
1968 # Rewrite of operators
1969 op_rewrite_list = [
1970 set_tensor_equivalence,
1971 convert_mean_to_depthwise_conv_or_avgpool,
1972 convert_depthwise_to_conv,
1973 convert_conv_to_fc,
1974 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001975 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001976 convert_mul_max_to_abs_or_lrelu,
1977 convert_lrelu,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001978 optimise_strided_conv,
1979 convert_hardswish_to_lut,
1980 rewrite_fully_connected_input,
1981 convert_batched_fc_shape,
1982 fixup_conv2d_backprop,
1983 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001984 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001985 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001986 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001987 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001988 convert_tanh_sigmoid_to_lut,
1989 replace_pad_by_hw_pad,
1990 ]
1991
1992 for idx, sg in enumerate(nng.subgraphs):
1993 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001994 nng,
1995 sg,
1996 arch,
1997 [],
1998 op_rewrite_list,
1999 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002000 )
2001
2002 for idx, sg in enumerate(nng.subgraphs):
2003 # remove passthrough tensors and attempt further optimizations
2004 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
2005 nng,
2006 sg,
2007 arch,
2008 [remove_passthrough_tensor],
2009 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
2010 )
2011
2012 # Removal of SplitSliceRead; this needs to be done after the optimisations above,
2013 # since ifm/ofm_shapes are of importance to this function
2014 for sg in nng.subgraphs:
2015 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
2016 sg.refresh_after_modification()
2017
2018 return nng