# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
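        # e.g. packing inputs of shape [2, 3] along axis 1 gives a desired_shape of [2, 1, 3] for each input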

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
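        # Illustrative example (assuming needed_total_padding returns the usual SAME-padding total):
        # an input height of 224 with stride 2 and a dilated kernel height of 3 gives ypad == 1, so
        # top_pad == 0 and bottom_pad == 1, i.e. the odd padding row ends up at the bottom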
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    elif padding_type == Padding.TILE:
        # The values in the explicit padding only represent the "direction" in which to pad
        top_pad, left_pad, bottom_pad, right_pad = explicit_padding
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to a depthwise convolution
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype == DataType.uint8:
        weight_value_dtype = np.uint8
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        if ofm_dtype == DataType.int8:
            weight_value_dtype = np.int8
        else:
            assert ofm_dtype == DataType.int16
            weight_value_dtype = np.int16

        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
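    # e.g. for upscale_factor == 2 the centre coefficient index is (2 // 2) * 2 + (2 // 2) == 3, i.e. the
    # bottom-right position 'D' of the flattened 2x2 kernel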

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm.dtype,
            np.array(weight_values).reshape(weight_shape),
            value_dtype=weight_value_dtype,
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None, DataType.int32)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()

    return op
# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))
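    # e.g. an upscale_factor of 8 gives n == 3: two x2 upscaling steps in the loop below plus the final
    # x2 step handled after the loop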

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # keep 1x1 kernel and average pool
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()

    return op

def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
    def _compute_interpolation_values(index, input_size, output_size):
        scale = input_size / output_size
        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
        lower_bound = max(np.floor(scaled_value), 0)

        return scaled_value, lower_bound

    def _compute_kernels(input_height, input_width, output_height, output_width):
        kernels = []
        for y in (1, 2):
            for x in (1, 2):
                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)

                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
                # top-to-bottom - same as the depthwise convolution strides across each tile
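                # Illustrative example (assuming a 2x upscale with half_pixel_centers=True): for (x, y) = (1, 1)
                # the scaled value is (1 + 0.5) * 0.5 - 0.5 = 0.25 with lower bound 0, so after the x16 scaling
                # below this kernel becomes [[1, 3], [3, 9]], the familiar 1/16-weighted bilinear 2x kernel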
                kernel = np.zeros((2, 2))
                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
                kernel *= 16
                kernels.append(kernel)

        return kernels

    def _build_convolutions(op, kernels):
        dw_op_attrs = {
            "padding": Padding.TILE,
            "stride_h": 1,
            "stride_w": 1,
            "strides": (1, 1, 1, 1),
            "depth_multiplier": 1,
            "channel_multiplier": 1,
            "dilation_h_factor": 1,
            "dilation_w_factor": 1,
            "dilation": (1, 1, 1, 1),
        }
        ifm = op.ifm
        ofm = op.ofm
        ofm.ops = []
        elem_size = 2 if ofm.dtype == DataType.int16 else 1

        n, h, w, c = ifm.shape
        _, _, ow, _ = ofm.shape

        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
        intermediate_tens.quantization = op.outputs[0].quantization.clone()
        avgpool_op = op
        avgpool_op.name = "rb_init_avgpool"
        avgpool_op.type = Op.AvgPool
        avgpool_op.attrs["padding"] = Padding.VALID
        avgpool_op.attrs["stride_w"] = 1
        avgpool_op.attrs["stride_h"] = 1
        avgpool_op.attrs["filter_width"] = 1
        avgpool_op.attrs["filter_height"] = 1
        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]

        avgpool_op.add_input_tensor(ifm)
        avgpool_op.set_output_tensor(intermediate_tens)
        avgpool_op.set_ifm_ofm_shapes()

        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
        dw_conv._original_type = Op.ResizeBilinear
        dw_conv.write_shape = Shape4D(n, h, w, c)
        dw_conv.write_offset = Shape4D(0, 0, 0, 0)

        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
        # values to be incorrectly rounded
        ofm.quantization.next_after = True
        dw_conv.rounding_mode = NpuRoundingMode.NATURAL

        # Double height and width stride to write the output of each of the four depthwise convolutions below
        # interleaved with each other when combined with OFM tile base offsets.
        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W

        # Choose tile padding direction - pad by 1 with edge values in two directions.
        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
        for i in (0, 1):
            for j in (0, 1):
                index = i * 2 + j
                dw_conv.name = f"depthwise_conv_{index}"
                dw_op_attrs["explicit_padding"] = directions[index]
                dw_conv.attrs.update(dw_op_attrs)

                # This will offset the start of the write by modifying the Tile 0 base address
                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size

                ofm.ops.append(dw_conv)
                dw_conv.outputs = [ofm]

                kernel = kernels[index]
                shape = [2, 2, 1, c]
                kernel = np.dstack([kernel] * c)

                quant = QuantizationParameters()
                quant.zero_point = 0
                quant.scale_f32 = 1.0 / 16

                dw_conv.inputs = []
                dw_conv.add_input_tensor(intermediate_tens)
                dw_conv.add_input_tensor(
                    create_const_tensor(
                        "weights",
                        shape,
                        intermediate_tens.dtype,
                        np.array(kernel).reshape(shape),
                        value_dtype=np.int8,
                        quantization=quant,
                    ),
                )

                # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
                # need to append the bias tensor as resize ops only have 2 inputs
                assert len(dw_conv.inputs) == 2
                dw_conv.inputs.append(None)
                fixup_bias_tensors(dw_conv, None, None, dtype=DataType.int32)

                dw_conv.set_ifm_ofm_shapes()
                dw_conv = dw_conv.clone(f"_{index}")
        return op

    _, input_height, input_width, _ = op.ifm.shape
    _, output_height, output_width, _ = op.ofm.shape

    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
    op = _build_convolutions(op, kernels)

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
            convert_resizebilinear_to_depthwise_convolutions(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape
    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
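            # e.g. a batch of 8 is mapped to a 1x2x4xC shape; batch sizes without an entry in batching_split
            # fall back to a single row of width n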

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)
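            # e.g. a shrink_axis_mask of 0b100 re-inserts a 1 at axis 2, so the stored 4D OFM shape still
            # carries the dimension that StridedSlice squeezed away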

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op

def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
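        # e.g. a 1x128x128x3 IFM read with stride_x == 2 is instead treated as a 1x128x64x6 IFM read with
        # stride_x == 1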

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


834
835def convert_conv_to_fc(op, arch, nng):
836 # Conv 1x1 can be equivalent to Fully Connected.
837 # By representing certain convs as fully connected layers, Vela can better determine wether or not to use
838 # caching/double buffering for the weights.
839 # (Weights dont need to be reloaded for convs when IFM H and W are 1)
840 if op.type == Op.Conv2DBias:
841 h = op.ifm_shapes[0].height
842 w = op.ifm_shapes[0].width
843 kh, kw, _, _ = op.inputs[1].shape
844 if h == 1 and w == 1 and kh == 1 and kw == 1:
845 # Overwrite this op as a Fully Connected Op
846 op.name += "_fc"
847 op.type = Op.FullyConnected
848 op.attrs = {
849 "weights_format": 0,
850 }
851 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
852 weight_tensor = op.inputs[1]
James Peet7519d502021-07-19 16:47:58 +0100853 weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
854 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200855
856 DebugDatabase.add_optimised(op, op)
857 return op
858

def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own primary op to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
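            # e.g. an IFM scale of 0.5 and an OFM scale of 0.25 give a rescale factor of 2.0, which
            # quantise_scale expresses as a hardware multiplier/shift pair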
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_prelu(op, arch, nng):
    if op.type == Op.Prelu:
        ifm, alpha, ofm = op.get_ifm_ifm2_ofm()
        if None in (ifm, alpha, ofm):
            return op

        if alpha.values is not None:
            # If const alpha check for possible optimisations
            alpha_zp = alpha.quantization.zero_point
            alpha_scale = alpha.quantization.scale_f32
            # If all alpha values are the same the PReLU can be converted to LeakyRelu
            alpha_min = (alpha.values.min().astype(np.int) - alpha_zp) * alpha_scale
            alpha_max = (alpha.values.max().astype(np.int) - alpha_zp) * alpha_scale
            if alpha_min == alpha_max:
                # or even a Relu
                if alpha_min == 0:
                    new_op = Op.Relu
                else:
                    new_op = Op.LeakyRelu
                    op.attrs["alpha"] = alpha_min
                    # setup alpha_scaling for bit exact result
                    ifm_scale = ifm.quantization.scale_f32
                    ofm_scale = ofm.quantization.scale_f32
                    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha_scale, ofm_scale)
                    op.attrs["alpha_scaling"] = (alpha.values.min() - alpha_zp, alpha_scale, alpha_shift)
                # Change op type
                op.type = new_op
                op.name = op.name.replace("Prelu", new_op.name)
                del op.inputs[1]  # Remove alpha tensor
                return op
            elif alpha_max < 1:
                # If alpha_max is less than 1 convert PReLU to Max(alpha * IFM, identity * IFM)
                # Multiply with alpha tensor
                mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
                mul_alpha.add_input_tensor(ifm)
                mul_alpha.add_input_tensor(alpha)
                fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
                mul_alpha.set_output_tensor(fm_alpha)
                mul_alpha.set_ifm_ofm_shapes()
                DebugDatabase.add_optimised(op, mul_alpha)
                if check_quantized_tens_scaling_equal(ifm, ofm):
                    # No scaling is needed
                    fm_id = ifm
                else:
                    # Add multiplication with identity
                    mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
                    mul_identity.add_input_tensor(ifm)
                    # Create const tensor containing identity as scalar
                    quantization = ifm.quantization.clone()
                    quantization.scale_f32 = np.float32(1)
                    quantization.zero_point = 0
                    one = create_const_tensor("one_const", [], ifm.dtype, [1], quantization=quantization)
                    mul_identity.add_input_tensor(one)
                    # Make sure that fm_id is allocated to a different address than fm_alpha
                    fm_id = ofm.clone(op.name + "_id", set_unique=True)
                    mul_identity.set_output_tensor(fm_id)
                    mul_identity.set_ifm_ofm_shapes()

                # Combine scaled and alpha multiplied values
                max_op = Operation(Op.Maximum, op.name + "_max")
                max_op.add_input_tensor(fm_alpha)
                max_op.add_input_tensor(fm_id)
                max_op.set_output_tensor(ofm)
                max_op.set_ifm_ofm_shapes()

                DebugDatabase.add_optimised(op, max_op)
                ifm.consumer_list.remove(op)
                return max_op

        # Catch all PReLU conversion for the cases that could not be optimised above
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        fm_negative = ifm.clone(op.name + "_negative", set_unique=True)
        min_op.set_output_tensor(fm_negative)
        min_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, min_op)

        # and multiply with alpha tensor
        mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
        mul_alpha.add_input_tensor(fm_negative)
        mul_alpha.add_input_tensor(alpha)
        fm_alpha = ofm.clone(op.name + "_negative_alpha", set_unique=True)
        mul_alpha.set_output_tensor(fm_alpha)
        mul_alpha.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_alpha)

        # Select (and scale) values > 0
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_scaled = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_scaled)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)

        # Add scaled and alpha multiplied values (without scaling)
        add_op = Operation(Op.RescaleAdd, op.name + "_add")
        add_op.rescale = (1, 0)  # No scale or shift
        add_op.add_input_tensor(fm_alpha)
        add_op.add_input_tensor(fm_scaled)
        add_op.set_output_tensor(ofm)
        add_op.set_ifm_ofm_shapes()

        DebugDatabase.add_optimised(op, add_op)
        ifm.consumer_list.remove(op)
        op = add_op

    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |   \   /           This subgraph can be replaced with either
    |    Mul            an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul_ifms = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1]
            if len(mul_ifms):
                mul = mul_ifms[0].ops[0]
            else:
                # Not using same input
                return op
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scaling", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    alpha = np.float32(op.attrs["alpha"])
    use_mul_max = 0 < alpha < 1
    is_converted_prelu = "alpha_scaling" in op.attrs
    if use_mul_max:
        mul_ifm = ifm
        new_op = Op.Maximum
    else:
        # Need to use a different approach for alpha < 0 or alpha > 1
        no_scale_quant = ifm.quantization.clone()
        no_scale_quant.scale_f32 = None
        no_scale_quant.zero_point = 0
        zero = create_const_tensor("zero_const", [], ifm.dtype, [0], quantization=no_scale_quant)

        # Select values < 0
        min_op = Operation(Op.Minimum, op.name + "_min")
        min_op.add_input_tensor(ifm)
        min_op.add_input_tensor(zero)
        mul_ifm = ifm.clone(op.name + "_negative", set_unique=True)
        if alpha < 0 and not is_converted_prelu:
            # For negative alpha that is not from a converted PReLU we need to use
            # int32 Mul below to perform the (negative) alpha scaling
            mul_ifm.dtype = DataType.int32
        min_op.set_output_tensor(mul_ifm)
        min_op.set_ifm_ofm_shapes()
        new_op = Op.RescaleAdd
        op.rescale = (1, 0)  # No scale or shift
        DebugDatabase.add_optimised(op, min_op)

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(mul_ifm)
    # Create const tensor containing alpha as scalar
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    alpha_dtype = mul_ifm.dtype
    if is_converted_prelu:
        # The LeakyRelu was the result from convert_prelu and the scaling is provided
        scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
        mul_alpha.type = Op.RescaleMul
        mul_alpha.rescale = [alpha_scale, alpha_shift]
    elif alpha == 0 or np.isinf(1 / alpha):
        # Handling of alpha near or at zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        if alpha_dtype == DataType.int32:
            # When the datatype is int32 (alpha negative) we need to do the scaling with the multiplication
            scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
        else:
            scalar = 1
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
    )
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if not use_mul_max:
        relu_op = Operation(Op.Relu, op.name + "_relu")
        relu_op.add_input_tensor(ifm)
        fm_id = ofm.clone(op.name + "_positive_scaled", set_unique=True)
        relu_op.set_output_tensor(fm_id)
        relu_op.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, relu_op)
    elif check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = new_op
    op.name = op.name.replace("LeakyRelu", new_op.name)
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
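    # Illustrative example: with both zero points at 0, table entry x simply becomes
    # round_away_zero(fn(ifm_scale * x) / ofm_scale), clamped to the quantised range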
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


1318def convert_lrelu(op, arch, nng):
1319 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1320 if op.type != Op.LeakyRelu:
1321 return op
1322 ifm, ofm = op.get_ifm_ofm()
1323 if ifm is None or ofm is None:
1324 return op
Fredrik Svedberg36424312022-09-16 09:39:26 +02001325 alpha = op.attrs["alpha"]
1326 if alpha == 0:
1327 # When alpha is 0 the operation can be converted to a ReLU
1328 op.type = Op.Relu
1329 op.name = op.name.replace("LeakyRelu", op.type.name)
1330 return op
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001331 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1332 # use LUT for int8/uint8
1333 return convert_lrelu_to_lut(op, arch)
Fredrik Svedberg36424312022-09-16 09:39:26 +02001334 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16 and alpha > 0:
Fredrik Svedberg701ba912022-09-07 16:01:15 +02001335 # use LeakyRelu unmodified for int16 with equal input/output scaling and positive alpha
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001336 return op
1337 return convert_lrelu_to_mul_max(op, arch)
1338
1339
1340def convert_tanh_sigmoid_to_lut(op, arch, nng):
1341 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1342 if op.type == Op.Sigmoid:
1343 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1344 elif op.type == Op.Tanh:
1345 return convert_to_lut8(op, math.tanh, "tanh")
1346 return op
1347
1348
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001349def remove_memory_only_ops(op, arch):
1350 if op.run_on_npu and op.type in memory_only_ops:
1351 bypass_memory_only_ops(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001352
1353
1354def fuse_activation_function_with_prev(op, arch, nng):
1355 # if op is a no-op: attempts to move the activation function to the preceding op
1356 if not op.attrs.get("is_nop", False) or op.activation is None:
1357 return op
1358 ifm, ofm = op.get_ifm_ofm()
1359 if ifm is None or ofm is None:
1360 return op
1361 # finds the input(s) to the operation
1362 prev_op = ifm.ops[0]
1363 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1364 fuse = (
1365 prev_op.run_on_npu
1366 and prev_op.type.npu_block_type != NpuBlockType.Default
1367 and len(ifm.ops) == 1
1368 and len(prev_op.outputs[0].consumers()) == 1
1369 and prev_op.activation is None
1370 )
1371 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1372 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1373 # the LUT currently only works correctly for elementwise ops
1374 fuse = False
1375 if not fuse:
1376 return op
1377 # Move the fused activation function + corresponding info to prev_op
1378 prev_op.activation = op.activation
1379 prev_op.forced_output_quantization = op.forced_output_quantization
1380 if op.activation_lut is not None:
1381 prev_op.set_activation_lut(op.activation_lut)
1382 # Bypass op
1383 prev_op.set_output_tensor(ofm)
1384 DebugDatabase.add_optimised(op, prev_op)
1385 return op
1386
1387
1388def _leading_pad_ok(leading_pad, stride, kernel_size):
1389 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1390 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
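# Illustration: kernel_size 7 and stride 2 give max_size 3; a leading pad of 1 is
# rejected, while 2 (a multiple of the stride) or 3 (equal to max_size) is accepted.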
1391 max_size = kernel_size // 2
1392 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1393
1394
1395def replace_pad_by_hw_pad(op: Operation, arch, nng):
1396 """
1397 Tries to completely remove a PAD operator by using hardware padding.
1398 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1399 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1400 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1401 if both operations can be run on the NPU.
1402 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1403 """
1404 if (
1405 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001406 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001407 and op.run_on_npu
1408 and op.attrs["padding"] == Padding.VALID
1409 ):
1410 pad_op = op.ifm.ops[0]
1411 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1412 return op
1413 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1414 return op
1415 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1416 k = op.kernel
1417 k_w, k_h = k.dilated_wh()
1418
1419 # Check if the PAD operator can be replaced by hardware padding
1420 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1421 # Too much padding, it would require hardware padding to actually insert zeros
1422 return op
1423 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1424 return op
1425
1426 if op.type.is_avgpool_op():
1427 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1428 for pad, k_size in (
1429 (left, k_w),
1430 (right, k_w),
1431 (top, k_h),
1432 (bottom, k_h),
1433 ):
1434 if pad not in (0, k_size // 2):
1435 return op
1436 # Average pool is converted to depthwise, because NPU average pool + same padding
1437 # has a special implementation that is different from PAD followed by average pool with
1438 # valid padding.
1439 k_w, k_h = op.kernel.width, op.kernel.height
1440 ifm = op.ifm
1441 # Remember other inputs
1442 other_inputs = op.inputs[1:]
1443 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1444 quantization = QuantizationParameters(0.0, 255.0)
1445 quantization.scale_f32 = 1.0 / (k_w * k_h)
1446 quantization.zero_point = 0
1447 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1448 weights = np.full(shape, 1)
1449
1450 weight_tens = create_const_tensor(
1451 op.name + "_weights",
1452 shape,
1453 op.ifm.dtype,
1454 weights,
1455 np.uint8,
1456 purpose=TensorPurpose.Weights,
1457 quantization=quantization,
1458 )
James Peet7519d502021-07-19 16:47:58 +01001459 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001460 op.type = Op.DepthwiseConv2DBias
1461 op.inputs = []
1462 op.add_input_tensor(ifm)
1463 op.add_input_tensor(weight_tens)
1464 # Add bias tensor, all biases set to 0
1465 op.inputs.append(None)
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001466 fixup_bias_tensors(op, arch, nng, DataType.int32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001467 # Add other inputs
1468 op.inputs.extend(other_inputs)
1469 op.rounding_mode = NpuRoundingMode.NATURAL
1470
1471 # Bypass the PAD operator
1472 op.set_input_tensor(pad_op.ifm, 0)
1473 # Adjust the padding attributes of the convolution operator
1474 op.attrs["padding"] = Padding.EXPLICIT
1475 op.attrs["explicit_padding"] = (top, left, bottom, right)
1476 op.set_ifm_ofm_shapes()
1477 return op
1478
1479
1480def convert_pad(op: Operation, arch, nng):
1481 """
1482 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1483 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1484 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1485 """
1486 if op.type != Op.Pad or not op.run_on_npu:
1487 return op
1488 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1489
1490 ifm = op.ifm
1491 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001492 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001493 ofm = op.ofm
1494 assert ofm is not None
1495 ofm.ops = []
1496 ofm_shape = op.ofm_shapes[0]
1497
1498 # Average pool op that copies IFM to the right place inside the OFM
1499 shp0 = Shape4D(0, 0, 0, 0)
1500 shp_top = shp0.with_height(top)
1501 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1502 avgpool_op.activation = op.activation
1503 quant = ofm.quantization
1504 pad_value = quant.zero_point
1505 # Add operations that fill the borders of the OFM
1506 if top > 0:
1507 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1508 zero_tens = create_const_tensor(
1509 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1510 )
1511 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1512 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1513 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1514 if bottom > 0:
1515 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1516 zero_tens = create_const_tensor(
1517 op.name + "_bottom",
1518 shape.as_list(),
1519 ofm.dtype,
1520 shape.elements() * [pad_value],
1521 np.uint8,
1522 quantization=quant,
1523 )
1524 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1525 create_avg_pool_for_concat(
1526 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1527 )
1528 if left > 0:
1529 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1530 zero_tens = create_const_tensor(
1531 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1532 )
1533 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1534 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1535 if right > 0:
1536 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1537 zero_tens = create_const_tensor(
1538 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1539 )
1540 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1541 create_avg_pool_for_concat(
1542 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1543 )
1544
1545 op.type = Op.ConcatTFLite
1546 return avgpool_op
1547
1548
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001549def fixup_bias_tensors(op, arch, nng, dtype=None):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001550 if op.type.needs_bias() and op.bias is None:
1551 # Op has no bias, add bias tensor filled with zeros
1552 nr_biases = op.inputs[1].shape[-1]
1553 bias_values = [0] * nr_biases
Fredrik Svedbergcc219be2022-09-20 16:32:52 +02001554 # The DataType of the bias tensor can be explicitly provided or deduced from the ifm
1555 # DataType. Default is int32 bias for 8-bit ifms and int64 for int16 ifms.
1556 # For int16 the selected bias DataType will have an impact on the scaling
1557 # used when encoding the scales and biases later. The default mode will match the
1558 # reference with reduced scaling for int64 bias.
1559 # This means that in cases (in the graph optimiser) where DepthwiseConv2DBias
1560 # is used to emulate average pool, int32 bias should be selected for full precision
1561 # int16 scaling.
1562 if dtype is None:
1563 dtype = DataType.int64 if op.ifm.dtype == DataType.int16 else DataType.int32
1564 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], dtype, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001565 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1566
1567 return op
1568
1569
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001570def fixup_asymmetric_weights(op, arch, nng):
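# int8 weights are expected to be symmetric (zero point 0) as per the TFLite quantisation
# spec; if the model deviates, the zero points are forced to 0 and a warning is printed.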
1571 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1572 if op.ifm.dtype == DataType.int8:
1573 if not np.all(op.weights.quantization.zero_point == 0):
1574 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1575 op.weights.quantization.zero_point *= 0
1576
1577 return op
1578
1579
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001580def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1581 if op.type == Op.Mean and op.run_on_npu:
1582 keep_dims = op.attrs.get("keep_dims", False)
1583 inp, axis = op.inputs
1584 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001585 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001586 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001587 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001588
1589 # Height and width axes have different index depending on dimensions
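# e.g. for a 3D HxWxC input axis 0 is the height and axis 1 the width, while for a
# 4D NHWC input axis 1 is the height and axis 2 the width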
1590 if axis.shape == [] or axis.shape[0] == 1: # single axis
1591 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1592 if dims in (2, 3):
1593 if axis == 0:
1594 h, w = shape[axis], 1
1595 else:
1596 h, w = 1, shape[axis]
1597 else:
1598 if axis == 1:
1599 h, w = shape[axis], 1
1600 else:
1601 h, w = 1, shape[axis]
1602 else: # multiple axes
1603 axis = sorted(axis.values)
1604 h, w = [shape[i] for i in axis]
1605
1606 # Set necessary depthwise attributes
1607 op.attrs.update(
1608 {
1609 "padding": Padding.VALID,
1610 "stride_h": 1,
1611 "stride_w": 1,
1612 "strides": (1, 1, 1, 1),
1613 "depth_multiplier": 1,
1614 "channel_multiplier": 1,
1615 "dilation_h_factor": 1,
1616 "dilation_w_factor": 1,
1617 "dilation": (1, 1, 1, 1),
1618 }
1619 )
1620 # Change op type
1621 op.type = Op.DepthwiseConv2DBias
1622 # Set IFM/OFM shapes after changing op type
1623 op.set_ifm_ofm_shapes()
1624
1625 weight_scale, bias = 1, None
1626 ofmq, ifmq = op.ofm.quantization, inp.quantization
1627 # Set rounding mode, scaling and zero point based on which reference implementation to match
1628 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1629 if inp.dtype == DataType.uint8:
1630 # This attribute means a different scaling calculation is used in order to match reference
1631 op.low_precision_scaling = True
1632 weight_scale = h * w
1633 # Set zero points to 0 as they will be adjusted for with bias term
1634 foq = ofmq.clone()
1635 foq.zero_point = 0
1636 fiq = ifmq.clone()
1637 fiq.zero_point = 0
1638 op.forced_input_quantization = fiq
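# The zero points were forced to 0 above, so compensate with a constant bias term:
#   bias_term = zp_ofm - round_up(zp_ifm * scale_ifm / scale_ofm)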
Johan Alfvén17009392022-08-30 09:14:56 +02001639 bias_term = ofmq.zero_point - round_up_to_int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001640 # If the bias term is outside uint8 range, we need an Add op to apply it.
1641 if bias_term < 0 or bias_term > 255:
1642 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1643 # Bias term has higher bitness (i32) than input/output (u8).
1644 # 16 bits is enough since the bias is added to/subtracted from a u8 value;
1645 # the bias can only effectively assume values in the range [-255, 255].
1646 intermediate.dtype = DataType.int16
1647 intermediate.quantization.zero_point = 0
1648 add_op = Operation(Op.Add, op.name + "_bias")
1649 add_op.forced_output_quantization = foq
1650 add_op.add_input_tensor(intermediate)
1651 quant = QuantizationParameters()
1652 quant.zero_point = 0
1653 bias_term_tens = create_const_tensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001654 op.name + "_bias",
1655 [1, 1, 1, 1],
1656 DataType.int16,
1657 [bias_term],
1658 np.int16,
1659 quantization=quant,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001660 )
1661 add_op.add_input_tensor(bias_term_tens)
1662 add_op.set_output_tensor(op.ofm)
1663 add_op.set_ifm_ofm_shapes()
1664 add_op.activation = op.activation
1665 op.activation = None
1666 op.set_output_tensor(intermediate)
1667 op.set_ifm_ofm_shapes()
1668 # If not, we can just do it with the OFM zero point.
1669 else:
1670 foq.zero_point = bias_term
1671 op.forced_output_quantization = foq
1672 else:
1673 assert inp.dtype == DataType.int8
1674 # Use a depthwise to calculate the sum,
1675 # followed by a multiplication with 1/N to get the MEAN
1676 weight_scale = 1
1677 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
Johan Alfvén05916632022-09-06 20:33:22 +02001678 intermediate.dtype = DataType.int32
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001679 mul_op = Operation(Op.Mul, op.name + "_mul")
1680 mul_op.add_input_tensor(intermediate)
Johan Alfvén05916632022-09-06 20:33:22 +02001681 mul_op.set_output_tensor(op.ofm)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001682 # Create scalar containing 1/N
1683 quant = QuantizationParameters()
1684 quant.zero_point = 0
1685 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1686 # while rounding mode NATURAL would round this to -1.
1687 # This can only occur if N is even, and can be emulated by
1688 # multiplying by a number that is slightly smaller than 1/N.
1689 # The adjustment must be so small that no other roundings are affected;
1690 # the calculated value is based on the worst case,
1691 # which is a sum of 256 * N (the maximum sum that can occur with int8)
1692 n = int(h * w)
1693 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1694 quant.scale_f32 = 1 / (n - eps)
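# e.g. for a 2x2 mean (n=4) a sum of -6 multiplied by 1/(4 - eps) gives about
# -1.5003, which NATURAL rounding takes to -2, matching the reference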
Johan Alfvén05916632022-09-06 20:33:22 +02001695
1696 # For int8/int16 we could use IFM/OFM scaling to do the division
1697 # intermediate * 1 -> scale > round and shift.
1698 #
1699 # For int32 scaling is not supported so instead multiply with the scale
1700 # intermediate * scale -> round and shift.
1701 #
1702 # Calculate the scale and shift values. The const tensor must be created
1703 # with correct quantization since the scale and shift are calculated later
1704 # in the command stream generator.
1705 mul_scale, _ = scaling.elementwise_mul_scale(
1706 mul_op.ifm.quantization.scale_f32, quant.scale_f32, mul_op.ofm.quantization.scale_f32
1707 )
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001708 scalar = create_const_tensor(
Johan Alfvén05916632022-09-06 20:33:22 +02001709 op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [mul_scale], np.int32, quantization=quant
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001710 )
1711 mul_op.add_input_tensor(scalar)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001712 mul_op.set_ifm_ofm_shapes()
1713 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1714 mul_op.activation = op.activation
1715 op.activation = None
1716 op.set_output_tensor(intermediate)
1717 op.set_ifm_ofm_shapes()
1718 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1719 # Here we can just use a simple AvgPool with truncating rounding,
1720 # as we're emulating simple integer division.
1721 op.rounding_mode = NpuRoundingMode.TRUNCATE
1722 op.type = Op.AvgPool
1723 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1724 else:
1725 op.rounding_mode = NpuRoundingMode.NATURAL
1726 weight_scale = 1 / (h * w)
1727 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1728 bias = -ifmq.zero_point * h * w
1729 fiq = ifmq.clone()
1730 fiq.zero_point = 0
1731 op.forced_input_quantization = fiq
1732
1733 # Change dimensions to 4
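# e.g. a 2D [H, W] shape becomes [1, H, W, 1] and a 3D [H, W, C] shape becomes [1, H, W, C]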
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001734 def extend_dims(dim, in_shape):
1735 if dim < 4:
1736 in_shape = [1] + in_shape
1737 if dim == 2:
1738 in_shape += [1]
1739 return in_shape
1740
1741 if dims < 4 or dims_ofm < 4:
1742 # Fix the ofm dimension when keep_dims is false
1743 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1744 if isinstance(axis, int) and dims_ofm + 1 == dims:
1745 ofm_shape.insert(axis, 1)
1746 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1747 for i in axis:
1748 ofm_shape.insert(i, 1)
1749 shape = extend_dims(dims, shape)
1750 dims_ofm = len(ofm_shape)
1751 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1752 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001753
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001754 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
1755 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001756 shape = [shape[0], 1, h * w, shape[3]]
1757 op.ifm_shapes[0] = Shape4D(shape)
1758 if h > 256 and op.type == Op.AvgPool:
1759 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1760
1761 # If the AvgPool version is used, we don't need to do anything else
1762 if op.type == Op.AvgPool:
1763 return op
1764
1765 # Make unit weight tensor quantization
1766 weight_quant = ifmq.clone()
1767 weight_quant.min = 0
1768 weight_quant.max = 255
1769 weight_quant.scale_f32 = weight_scale
1770 weight_quant.zero_point = 0
1771
1772 # Set weight shape to [H,W,C,B]
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001773 weight_shape = [h, w, shape[3], shape[0]]
1774
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001775 # Add unit weight tensor
1776 op.set_input_tensor(
1777 create_const_tensor(
1778 "weights",
1779 weight_shape,
1780 inp.dtype,
1781 np.ones(weight_shape),
1782 value_dtype=np.uint8,
1783 quantization=weight_quant,
1784 ),
1785 1,
1786 )
James Peet7519d502021-07-19 16:47:58 +01001787 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001788
1789 # Add None bias tensor
1790 op.inputs.append(None)
1791 # Add bias tensor
1792 if bias:
1793 bias_shape = [shape[-1]]
1794 op.set_input_tensor(
1795 create_const_tensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001796 "bias",
1797 bias_shape,
1798 inp.dtype,
1799 np.ones(bias_shape) * bias,
1800 value_dtype=np.int32,
1801 quantization=None,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001802 ),
1803 2,
1804 )
1805
1806 return op
1807
1808
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001809def optimise_quantize(op: Operation, arch, nng):
1810
1811 if op.type == Op.Quantize and op.run_on_npu:
1812
1813 ifm, ofm = op.get_ifm_ofm()
1814 input_values = ifm.values
1815
1816 # Guard clause - input not const or no values to quantize
1817 if ifm.ops[0].type != Op.Const or input_values is None:
1818 return op
1819
1820 # Singular val in numpy array, convert to indexable array
1821 if input_values.ndim == 0:
1822 input_values = np.array([input_values])
1823
Fredrik Svedberg11563172022-07-06 14:54:12 +02001824 # Case: requantize int8 to int8 or int16 to int16
1825 if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001826
1827 # scale needs to use double precision to match TFLite reference kernel
1828 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1829 effective_multiplier, effective_shift = quantise_scale(effective_scale)
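# Requantisation follows the TFLite reference:
#   ofm_val = clamp(zp_out + multiply_by_quantized_multiplier(val - zp_in, multiplier, shift))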
1830
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001831 requantized_vals = []
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001832 for val in input_values.flatten():
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001833 input_val = val - ifm.quantization.zero_point
1834
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001835 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1836 ofm_val += ofm.quantization.zero_point
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001837
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001838 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1839 requantized_vals.append(clamped_ofm_value)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001840
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001841 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1842 ofm.values.shape = input_values.shape
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001843
1844 # Case: Float input - quantize to int
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001845 elif ifm.dtype.type == BaseType.Float:
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001846
1847 quantized_vals = []
1848 for val in input_values:
1849
1850 # Derive quantized value
1851 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001852 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1853 quantized_vals.append(clamped_quantized_val)
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001854
1855 # Pass the statically calculated quant val to output tensor
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001856 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1857
1858 # Unsupported data type
1859 else:
1860 return op
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001861
1862 # Make quantize op const and disconnect from parent node
1863
1864 # Remove reference of the current quant op from the parent tensor's consumer list
1865 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1866
1867 # Clear any references to parent node
1868 op.inputs = []
1869
1870 # Convert this quantize op to const
1871 op.type = Op.Const
1872
1873 return op
1874
1875
Ayaan Masood4965fae2022-06-29 11:30:57 +01001876def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1877 """Static optimisation for SHAPE operator output value known at compile time"""
1878
1879 # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
1880
1881 if op.type == Op.Shape and op.run_on_npu:
1882
1883 ifm, ofm = op.get_ifm_ofm()
1884
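# Only rewrite when the SHAPE output holds one element per ifm dimension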
1885 if len(ifm.shape) != ofm.shape[0]:
1886 return op
1887
1888 # Remove reference of the current shape op from the parent tensor's consumer list
1889 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1890
1891 # Clear any references to parent node
1892 op.inputs = []
1893
1894 # Convert this SHAPE op to const
1895 op.type = Op.Const
1896
1897 # Set the output values to the statically known ifm shape
1898 ofm.values = np.array(ifm.shape)
1899
1900 return op
1901
1902
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001903def supported_operator_check(op, arch, nng):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001904 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001905 return op
1906
1907
1908def tflite_optimise_graph(nng, arch):
Fredrik Svedberg11563172022-07-06 14:54:12 +02001909 # Compile time static optimisations
Ayaan Masood25f48dd2022-06-29 18:16:04 +01001910 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1911
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001912 for idx, sg in enumerate(nng.subgraphs):
1913 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001914 nng,
1915 sg,
1916 arch,
1917 [],
Ayaan Masood4965fae2022-06-29 11:30:57 +01001918 optimisation_list,
1919 rewrite_unsupported=False,
1920 )
1921
Fredrik Svedberga04f2f72022-07-06 13:42:24 +02001922 # Pre-processing step
1923 pre_process_list = [
1924 supported_operator_check,
1925 set_ifm_ofm_op_shapes,
1926 ]
1927
Ayaan Masood4965fae2022-06-29 11:30:57 +01001928 for idx, sg in enumerate(nng.subgraphs):
1929 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1930 nng,
1931 sg,
1932 arch,
1933 [],
Jonas Ohlssond8575072022-03-30 10:30:25 +02001934 pre_process_list,
1935 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001936 )
1937
1938 # Handle Concat Ops
1939 for idx, sg in enumerate(nng.subgraphs):
1940 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1941 sg.refresh_after_modification()
1942
1943 # Handle Split Ops
1944 for idx, sg in enumerate(nng.subgraphs):
1945 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1946 nng,
1947 sg,
1948 arch,
1949 [],
1950 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1951 rewrite_unsupported=False,
1952 )
1953
1954 for idx, sg in enumerate(nng.subgraphs):
1955 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001956 nng,
1957 sg,
1958 arch,
1959 [rewrite_split_ops],
1960 [],
1961 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001962 )
1963
1964 # Handle sg input output
1965 for idx, sg in enumerate(nng.subgraphs):
1966 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001967 nng,
1968 sg,
1969 arch,
1970 [],
1971 [fix_sg_input_output],
1972 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001973 )
1974
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001975 # Removal of memory only operators
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001976 for sg in nng.subgraphs:
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +02001977 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001978 sg.refresh_after_modification()
1979
1980 # Rewrite of operators
1981 op_rewrite_list = [
1982 set_tensor_equivalence,
1983 convert_mean_to_depthwise_conv_or_avgpool,
1984 convert_depthwise_to_conv,
1985 convert_conv_to_fc,
1986 convert_softmax,
Fredrik Svedberg8ddd4892022-08-19 16:06:04 +02001987 convert_prelu,
Fredrik Svedberg36424312022-09-16 09:39:26 +02001988 convert_mul_max_to_abs_or_lrelu,
1989 convert_lrelu,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001990 optimise_strided_conv,
1991 convert_hardswish_to_lut,
1992 rewrite_fully_connected_input,
1993 convert_batched_fc_shape,
1994 fixup_conv2d_backprop,
1995 fixup_relus_with_differing_ifm_ofm_scaling,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001996 reorder_depthwise_weights,
Tim Hall885033b2022-07-21 11:46:03 +01001997 fixup_resize,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001998 fixup_bias_tensors,
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001999 fixup_asymmetric_weights,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002000 convert_tanh_sigmoid_to_lut,
2001 replace_pad_by_hw_pad,
2002 ]
2003
2004 for idx, sg in enumerate(nng.subgraphs):
2005 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
Jonas Ohlssond8575072022-03-30 10:30:25 +02002006 nng,
2007 sg,
2008 arch,
2009 [],
2010 op_rewrite_list,
2011 rewrite_unsupported=False,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002012 )
2013
2014 for idx, sg in enumerate(nng.subgraphs):
2015 # remove passthrough tensors and attempt further optimizations
2016 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
2017 nng,
2018 sg,
2019 arch,
2020 [remove_passthrough_tensor],
2021 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
2022 )
2023
2024 # Removal of SplitSliceRead needs to be done after optimisation has been performed,
2025 # since ifm/ofm_shapes are of importance to this function
2026 for sg in nng.subgraphs:
2027 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
2028 sg.refresh_after_modification()
2029
2030 return nng