# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


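# Worked example for calc_padding_and_skirt() above (illustrative values, assuming
# needed_total_padding implements the usual SAME-padding rule
# total = max((ceil(in / stride) - 1) * stride + kernel - in, 0)):
#   input width 224, stride 2, dilated kernel width 3 -> xpad = 1,
#   which SAME padding splits as left_pad = xpad // 2 = 0 and right_pad = (xpad + 1) // 2 = 1;
#   the skirt keeps the leading pads plus the remainders (ypad - top_pad, xpad - left_pad).

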
def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resizebilinear_1x1_to_add(op):
    op.type = Op.Add
    op.name = op.name + "_add"
    op.attrs["resizebilinear"] = True
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear to a number of 2x2 nearest neighbor upscaling ops and one avgpool op with kernel size dependent
# on the upscaling factor. Avgpool kernel limit of 8x8 when padding is applied limits upscaling to 8x8.
def convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype
    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    if op.attrs["align_corners"]:
        shape_modifier = 1
        op.attrs["padding"] = Padding.VALID
    else:
        shape_modifier = 0
        op.attrs["padding"] = Padding.SAME
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    upscale_factor = int(round(out_shape[1] / upscaled_shape[1]))
    n = int(np.log2(upscale_factor))

    # Perform 2x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor 2x2 upscaling
        upscaled_shape = upscaled_shape * 2 - shape_modifier
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last 2x2 upscaling also applies avgpool with kernel size dependent on the upscaling factor and adds
    # padding to the right and bottom.
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]
        scaled_op.attrs["padding"] = Padding.EXPLICIT
        scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
        scaled_op.outputs = outputs
        scaled_op.outputs[0].ops = [scaled_op]
        scaled_op.set_ifm_ofm_shapes()

    return op


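# Illustrative walk-through of convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool()
# above, assuming align_corners is False and an 8x upscale (e.g. 1x10x10xC -> 1x80x80xC):
#   upscale_factor = 8, so n = log2(8) = 3
#   - two plain 2x2 nearest neighbour upscales are emitted (10 -> 20 -> 40),
#   - the final clone does the last 2x upscale (40 -> 80) combined with an 8x8 average pool
#     and EXPLICIT padding [0, 0, 7, 7] (right/bottom), approximating the bilinear result.

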
def fixup_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass nop resizebilinear
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resizebilinear_1x1_to_add(op)
        else:
            convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape
    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


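# Worked example for convert_batched_fc_shape() above (illustrative sizes): a FullyConnected op
# with IFM [8, 512] and weights (IO) [512, 10]:
#   batching_split[8] == (2, 4), so op.ifm_shapes[0] becomes Shape4D([1, 2, 4, 512]),
#   the weights are reshaped to HWIO [1, 1, 512, 10], and op.ofm_shapes[0] becomes
#   Shape4D([1, 2, 4, 10]); batch sizes not in the table fall back to (1, n).

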
def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


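# Note on the mask handling in rewrite_stridedslice_output() above: "mask &= mask - 1" clears
# the lowest set bit and int(math.log2(prev_mask - mask)) recovers that bit's index, so the loop
# visits the masked axes from lowest to highest. Illustrative example for an input of shape
# [2, 3, 4] with shrink_axis_mask = 0b101 (output tensor shape [3]):
#   1st iteration: axis 0 -> output_shape becomes [1, 3]
#   2nd iteration: axis 2 -> output_shape becomes [1, 3, 1]
#   so op.ofm_shapes[idx] is set to a Shape4D of [1, 3, 1] with the shrunk axes restored as 1s.

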
def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


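# Worked example for optimise_strided_conv() above (illustrative sizes, assuming the HWIO weight
# layout used elsewhere in this file): a stride-2 Conv2D with IFM 1x64x64x3 and a 3x3x3xCout kernel.
# Width 64 is even and depth 3 <= 4, so:
#   - the IFM shape is reinterpreted as 1x64x32x6 (each horizontal pixel pair packed into channels),
#   - the kernel width is padded with a zero-point column to 4 and reshaped to 3x2x6xCout, and
#   - the horizontal stride is reduced from 2 to 1, so the convolution runs as a stride-1 op.

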
def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
      |   \ /           This subgraph can be replaced with either
      |   Mul           an Abs (if X = -1) or a LeakyReLU (if X > 0)
      |   /
      Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


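# Reference note for convert_hardswish_to_lut() above: the function being tabulated is
# hard_swish(x) = x * relu6(x + 3) / 6, evaluated in fixed point for every possible quantized
# input value (256 entries), so the op can be replaced by a single LUT lookup on the NPU.

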
def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)


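# Worked example for convert_to_lut8() above (illustrative quantization parameters): with
# fn = sigmoid, ifm_scale = 0.1, zp_in = 128, ofm_scale = 1 / 256 and zp_out = 0, the LUT entry
# for x = 138 is round_away_zero(0 + sigmoid(0.1 * 10) / (1 / 256)) = 187, clamped to [0, 255].

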
def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0


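# Illustrative example for _leading_pad_ok() above: with a 7-wide kernel and stride 2,
# max_size = 7 // 2 = 3, which is larger than the stride, so a left/top pad of 1 is rejected
# (not equal to 3 and not a multiple of 2) while pads of 0, 2 or 3 are accepted.

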
def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                np.uint8,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
    return op


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op


def add_attrs_to_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        input_shape = op.ifm_shapes[0]
        upscaled_height = input_shape.height * 2
        upscaled_width = input_shape.width * 2
        out_shape = op.ofm_shapes[0]
        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
            # this means the output is supposed to be a x2 upscale,
            # so we need to do SAME padding
            op.attrs["padding"] = Padding.SAME
        elif (
            op.attrs["align_corners"]
            and out_shape.height == (upscaled_height - 1)
            and out_shape.width == (upscaled_width - 1)
        ):
            # here we can just run the avg pool without padding and
            # produce a (M * 2 - 1, N * 2 - 1) sized output
            op.attrs["padding"] = Padding.VALID
        else:
            return op
        op.ifm_resampling_mode = resampling_mode.NEAREST
        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    return op


def fixup_bias_tensors(op, arch, nng):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op


def fixup_asymmetric_weights(op, arch, nng):
    if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
        if op.ifm.dtype == DataType.int8:
            if not np.all(op.weights.quantization.zero_point == 0):
                print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
                op.weights.quantization.zero_point *= 0

    return op


Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001180def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1181 if op.type == Op.Mean and op.run_on_npu:
1182 keep_dims = op.attrs.get("keep_dims", False)
1183 inp, axis = op.inputs
1184 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001185 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001186 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001187 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001188
1189 # Height and width axes have different index depending on dimensions
1190 if axis.shape == [] or axis.shape[0] == 1: # single axis
1191 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1192 if dims in (2, 3):
1193 if axis == 0:
1194 h, w = shape[axis], 1
1195 else:
1196 h, w = 1, shape[axis]
1197 else:
1198 if axis == 1:
1199 h, w = shape[axis], 1
1200 else:
1201 h, w = 1, shape[axis]
1202 else: # multiple axes
1203 axis = sorted(axis.values)
1204 h, w = [shape[i] for i in axis]
1205
1206 # Set necessary depthwise attributes
1207 op.attrs.update(
1208 {
1209 "padding": Padding.VALID,
1210 "stride_h": 1,
1211 "stride_w": 1,
1212 "strides": (1, 1, 1, 1),
1213 "depth_multiplier": 1,
1214 "channel_multiplier": 1,
1215 "dilation_h_factor": 1,
1216 "dilation_w_factor": 1,
1217 "dilation": (1, 1, 1, 1),
1218 }
1219 )
1220 # Change op type
1221 op.type = Op.DepthwiseConv2DBias
1222 # Set IFM/OFM shapes after changing op type
1223 op.set_ifm_ofm_shapes()
1224
1225 weight_scale, bias = 1, None
1226 ofmq, ifmq = op.ofm.quantization, inp.quantization
1227 # Set rounding mode, scaling and zero point based on which reference implementation to match
1228 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1229 if inp.dtype == DataType.uint8:
1230 # This attribute means a different scaling calculation is used in order to match reference
1231 op.low_precision_scaling = True
1232 weight_scale = h * w
1233 # Set zero points to 0 as they will be adjusted for with bias term
1234 foq = ofmq.clone()
1235 foq.zero_point = 0
1236 fiq = ifmq.clone()
1237 fiq.zero_point = 0
1238 op.forced_input_quantization = fiq
1239 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1240 # If the bias term is outside uint8 range, we need an Add op to apply it.
1241 if bias_term < 0 or bias_term > 255:
1242 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1243 # Bias term has higher bitness (i32) than input/output (u8).
1244 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1245 # the bias can only effectively assume values in the range [-255, 255].
1246 intermediate.dtype = DataType.int16
1247 intermediate.quantization.zero_point = 0
1248 add_op = Operation(Op.Add, op.name + "_bias")
1249 add_op.forced_output_quantization = foq
1250 add_op.add_input_tensor(intermediate)
1251 quant = QuantizationParameters()
1252 quant.zero_point = 0
1253 bias_term_tens = create_const_tensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001254 op.name + "_bias",
1255 [1, 1, 1, 1],
1256 DataType.int16,
1257 [bias_term],
1258 np.int16,
1259 quantization=quant,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001260 )
1261 add_op.add_input_tensor(bias_term_tens)
1262 add_op.set_output_tensor(op.ofm)
1263 add_op.set_ifm_ofm_shapes()
1264 add_op.activation = op.activation
1265 op.activation = None
1266 op.set_output_tensor(intermediate)
1267 op.set_ifm_ofm_shapes()
1268 # If not, we can just do it with the OFM zero point.
1269 else:
1270 foq.zero_point = bias_term
1271 op.forced_output_quantization = foq
1272 else:
1273 assert inp.dtype == DataType.int8
1274 # Use a depthwise to calculate the sum,
1275 # followed by a multiplication with 1/N to get the MEAN
1276 weight_scale = 1
1277 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1278 intermediate.dtype = DataType.int16
1279 mul_op = Operation(Op.Mul, op.name + "_mul")
1280 mul_op.add_input_tensor(intermediate)
1281 # Create scalar containing 1/N
1282 quant = QuantizationParameters()
1283 quant.zero_point = 0
1284 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1285 # while rounding mode NATURAL would round this to -1.
1286 # This can only occur if N is even, and can be emulated by
1287 # multiplying with a number that is slightly smaller than 1/N.
1288 # It must be so small that other roundings are not affected;
1289 # the calculated value is based on worst case,
1290 # which is sum 256 * N (the maximum sum that can occur with int8)
1291 n = int(h * w)
1292 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1293 quant.scale_f32 = 1 / (n - eps)
1294 scalar = create_const_tensor(
1295 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1296 )
1297 mul_op.add_input_tensor(scalar)
1298 mul_op.set_output_tensor(op.ofm)
1299 mul_op.set_ifm_ofm_shapes()
1300 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1301 mul_op.activation = op.activation
1302 op.activation = None
1303 op.set_output_tensor(intermediate)
1304 op.set_ifm_ofm_shapes()
1305 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1306 # Here we can just use a simple AvgPool with truncating rounding,
1307 # as we're emulating simple integer division.
1308 op.rounding_mode = NpuRoundingMode.TRUNCATE
1309 op.type = Op.AvgPool
1310 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1311 else:
1312 op.rounding_mode = NpuRoundingMode.NATURAL
1313 weight_scale = 1 / (h * w)
1314 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1315 bias = -ifmq.zero_point * h * w
1316 fiq = ifmq.clone()
1317 fiq.zero_point = 0
1318 op.forced_input_quantization = fiq
1319
1320 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001321 def extend_dims(dim, in_shape):
1322 if dim < 4:
1323 in_shape = [1] + in_shape
1324 if dim == 2:
1325 in_shape += [1]
1326 return in_shape
1327
1328 if dims < 4 or dims_ofm < 4:
1329 # Fix the ofm dimension when keep_dims is false
1330 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1331 if isinstance(axis, int) and dims_ofm + 1 == dims:
1332 ofm_shape.insert(axis, 1)
1333 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1334 for i in axis:
1335 ofm_shape.insert(i, 1)
1336 shape = extend_dims(dims, shape)
1337 dims_ofm = len(ofm_shape)
1338 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1339 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001340
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001341 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
1342 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001343 shape = [shape[0], 1, h * w, shape[3]]
1344 op.ifm_shapes[0] = Shape4D(shape)
1345 if h > 256 and op.type == Op.AvgPool:
1346 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1347
1348 # If the AvgPool version is used, we don't need to do anything else
1349 if op.type == Op.AvgPool:
1350 return op
1351
1352 # Set up quantization for the unit weight tensor
1353 weight_quant = ifmq.clone()
1354 weight_quant.min = 0
1355 weight_quant.max = 255
1356 weight_quant.scale_f32 = weight_scale
1357 weight_quant.zero_point = 0
1358
1359 # Set weight shape to [H,W,C,B]
1360 weight_shape = [h, w, shape[3], shape[0]]
1361
1362 # Add unit weight tensor
1363 op.set_input_tensor(
1364 create_const_tensor(
1365 "weights",
1366 weight_shape,
1367 inp.dtype,
1368 np.ones(weight_shape),
1369 value_dtype=np.uint8,
1370 quantization=weight_quant,
1371 ),
1372 1,
1373 )
1374 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
1375
1376 # Add None bias tensor
1377 op.inputs.append(None)
1378 # Add bias tensor
1379 if bias:
1380 bias_shape = [shape[-1]]
1381 op.set_input_tensor(
1382 create_const_tensor(
1383 "bias",
1384 bias_shape,
1385 inp.dtype,
1386 np.ones(bias_shape) * bias,
1387 value_dtype=np.int32,
1388 quantization=None,
1389 ),
1390 2,
1391 )
1392
1393 return op
1394
1395
1396def optimise_quantize(op: Operation, arch, nng):
1397
1398 if op.type == Op.Quantize and op.run_on_npu:
1399
1400 ifm, ofm = op.get_ifm_ofm()
1401 input_values = ifm.values
1402
1403 # Guard clause - input not const or no values to quantize
1404 if ifm.ops[0].type != Op.Const or input_values is None:
1405 return op
1406
1407 # Scalar value in numpy array, convert to indexable array
1408 if input_values.ndim == 0:
1409 input_values = np.array([input_values])
1410
1411 # Requantize int8 to int8 or int16 to int16
1412 if (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) or (
1413 ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16
1414 ):
1415
1416 # scale needs to use double precision to match TFLite reference kernel
1417 effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
1418 effective_multiplier, effective_shift = quantise_scale(effective_scale)
1419
1420 requantized_vals = []
1421 for val in input_values.flatten():
1422 input_val = val - ifm.quantization.zero_point
1423
1424 ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
1425 ofm_val += ofm.quantization.zero_point
1426
1427 clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
1428 requantized_vals.append(clamped_ofm_value)
1429
1430 ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
1431 ofm.values.shape = input_values.shape
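# Roughly, ignoring the fixed-point details of multiply_by_quantized_multiplier, this computes
# q_out = round((q_in - zp_in) * scale_in / scale_out) + zp_out, clamped to the output range.
# Example with made-up values: scale_in 0.5, scale_out 0.25, zp_in 0, zp_out 10 give an
# effective_scale of 2.0, so q_in = 8 becomes 8 * 2 + 10 = 26 (both encode the real value 4.0).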
1432
1433 # Case: Float input - quantize to int
1434 elif ifm.dtype.type == BaseType.Float:
1435
1436 quantized_vals = []
1437 for val in input_values:
1438
1439 # Derive quantized value
1440 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
1441 clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
1442 quantized_vals.append(clamped_quantized_val)
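# Illustrative numbers (chosen only for clarity): with ofm scale 0.1 and zero point 5, a
# float input of 1.0 maps to 1.0 / 0.1 + 5 = 15.0 before clipping to [quant_min, quant_max].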
1443
1444 # Pass the statically calculated quant val to output tensor
1445 ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
1446
1447 # Unsupported data type
1448 else:
1449 return op
1450
1451 # Make quantize op const and disconnect from parent node
1452
1453 # Remove reference of the current quant op from the parent tensor's consumer list
1454 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1455
1456 # Clear any references to parent node
1457 op.inputs = []
1458
1459 # Convert this quantize op to const
1460 op.type = Op.Const
1461
1462 return op
1463
1464
1465def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
1466 """Static optimisation for SHAPE operator output value known at compile time"""
1467
1468 # Disconnect the SHAPE operator from its parent and transform it into a constant
1469
1470 if op.type == Op.Shape and op.run_on_npu:
1471
1472 ifm, ofm = op.get_ifm_ofm()
1473
1474 if len(ifm.shape) != ofm.shape[0]:
1475 return op
1476
1477 # Remove reference of the current shape op from the parent tensor's consumer list
1478 ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
1479
1480 # Clear any references to parent node
1481 op.inputs = []
1482
1483 # Convert this SHAPE op to const
1484 op.type = Op.Const
1485
1486 # Write the static shape of the input tensor into the output tensor's values
1487 ofm.values = np.array(ifm.shape)
1488
1489 return op
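# Sketch of the effect (illustrative shape): a SHAPE op whose input tensor has shape
# [1, 224, 224, 3] and whose output is a 4-element vector is replaced by a Const op with
# values [1, 224, 224, 3], so nothing needs to be evaluated on the NPU at run time.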
1490
1491
1492def supported_operator_check(op, arch, nng):
1493 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
1494 return op
1495
1496
1497def tflite_optimise_graph(nng, arch):
1498 # Compile time optimisations
1499 optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
1500
1501 for idx, sg in enumerate(nng.subgraphs):
1502 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1503 nng,
1504 sg,
1505 arch,
1506 [],
1507 optimisation_list,
1508 rewrite_unsupported=False,
1509 )
1510
1511 # Pre-processing step
1512 pre_process_list = [
1513 supported_operator_check,
1514 set_ifm_ofm_op_shapes,
1515 ]
1516
1517 for idx, sg in enumerate(nng.subgraphs):
1518 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1519 nng,
1520 sg,
1521 arch,
1522 [],
1523 pre_process_list,
1524 rewrite_unsupported=False,
1525 )
1526
1527 # Handle Concat Ops
1528 for idx, sg in enumerate(nng.subgraphs):
1529 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1530 sg.refresh_after_modification()
1531
1532 # Handle Split Ops
1533 for idx, sg in enumerate(nng.subgraphs):
1534 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1535 nng,
1536 sg,
1537 arch,
1538 [],
1539 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1540 rewrite_unsupported=False,
1541 )
1542
1543 for idx, sg in enumerate(nng.subgraphs):
1544 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1545 nng,
1546 sg,
1547 arch,
1548 [rewrite_split_ops],
1549 [],
1550 rewrite_unsupported=False,
1551 )
1552
1553 # Handle sg input output
1554 for idx, sg in enumerate(nng.subgraphs):
1555 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1556 nng,
1557 sg,
1558 arch,
1559 [],
1560 [fix_sg_input_output],
1561 rewrite_unsupported=False,
1562 )
1563
1564 # Removal of memory only operators
1565 for sg in nng.subgraphs:
1566 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
1567 sg.refresh_after_modification()
1568
1569 # Rewrite of operators
1570 op_rewrite_list = [
1571 set_tensor_equivalence,
1572 convert_mean_to_depthwise_conv_or_avgpool,
1573 convert_depthwise_to_conv,
1574 convert_conv_to_fc,
1575 convert_softmax,
1576 optimise_strided_conv,
1577 convert_hardswish_to_lut,
1578 rewrite_fully_connected_input,
1579 convert_batched_fc_shape,
1580 fixup_conv2d_backprop,
1581 fixup_relus_with_differing_ifm_ofm_scaling,
1582 reorder_depthwise_weights,
1583 fixup_resizebilinear,
1584 fixup_bias_tensors,
1585 fixup_asymmetric_weights,
1586 convert_mul_max_to_abs_or_lrelu,
1587 convert_lrelu,
1588 convert_tanh_sigmoid_to_lut,
1589 replace_pad_by_hw_pad,
1590 ]
1591
1592 for idx, sg in enumerate(nng.subgraphs):
1593 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1594 nng,
1595 sg,
1596 arch,
1597 [],
1598 op_rewrite_list,
1599 rewrite_unsupported=False,
1600 )
1601
1602 for idx, sg in enumerate(nng.subgraphs):
1603 # Remove passthrough tensors and attempt further optimisations
1604 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1605 nng,
1606 sg,
1607 arch,
1608 [remove_passthrough_tensor],
1609 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1610 )
1611
1612 # Removal of SplitSliceRead; this needs to be done after the optimisations above have run,
1613 # since this step relies on the ifm/ofm_shapes those passes set up
1614 for sg in nng.subgraphs:
1615 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1616 sg.refresh_after_modification()
1617
1618 return nng