# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op
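# Note: the "average pool" created above is a pass-through op; write_offset/write_shape together
# with Op.ConcatSliceWrite record where inside the concat OFM this input should be written,
# so a concat becomes one such write op per input tensor.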


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt
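# Illustrative example: a 3x3 kernel with stride 1 and Padding.SAME gives xpad = ypad = 2, so
# padding = (top, left, bottom, right) = (1, 1, 1, 1) and the skirt is also (1, 1, 1, 1).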


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resizebilinear_1x1_to_add(op):
    op.type = Op.Add
    op.name = op.name + "_add"
    op.attrs["resizebilinear"] = True
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear to a number of 2x2 nearest neighbor upscaling steps and one avgpool op with kernel size
# dependent on the upscaling factor. Avgpool kernel limit of 8x8 when padding is applied limits upscaling to 8x8.
def convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype
    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    if op.attrs["align_corners"]:
        shape_modifier = 1
        op.attrs["padding"] = Padding.VALID
    else:
        shape_modifier = 0
        op.attrs["padding"] = Padding.SAME
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    upscale_factor = int(round(out_shape[1] / upscaled_shape[1]))
    n = int(np.log2(upscale_factor))

    # Perform 2x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor 2x2 upscaling
        upscaled_shape = upscaled_shape * 2 - shape_modifier
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last 2x2 upscaling also applies avgpool with kernel size dependent on the upscaling factor and adds
    # padding to the right and bottom.
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]
        scaled_op.attrs["padding"] = Padding.EXPLICIT
        scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
        scaled_op.outputs = outputs
        scaled_op.outputs[0].ops = [scaled_op]
        scaled_op.set_ifm_ofm_shapes()

    return op
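# Illustrative example: a 4x upscale gives upscale_factor = 4 and n = 2; the op itself performs the
# first nearest neighbor 2x2 upscale and a cloned op performs the last one, using a 4x4 avgpool
# kernel with explicit [0, 0, 3, 3] padding to approximate the bilinear filtering.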


def fixup_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass nop resizebilinear
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resizebilinear_1x1_to_add(op)
        else:
            convert_resizebilinear_to_nearest_neighbor_upscaling_and_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape
    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op
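# Illustrative example: a FullyConnected with IFM batch 8 gets its IFM (and OFM) reshaped to
# 1x2x4xdepth via batching_split, and the 2D IO weights become 4D HWIO with H = W = 1.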


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op
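# Illustrative example: a stride-2 Conv2D with IFM 1xHxWxC (W even, C <= 4) is rewritten to a
# stride-1 Conv2D with IFM 1xHx(W/2)x(2C); the kernel is padded with zero-point columns to an even
# width if needed and then folded the same way (kernel width halved, input channels doubled).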


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |        \   /      This subgraph can be replaced with either
    |         Mul       an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op
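# Note: HardSwish(x) = x * ReLU6(x + 3) / 6; the table above evaluates this per quantized input
# value using 16-bit fixed-point arithmetic so that it matches the TensorFlow Lite Micro kernel.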


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)
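# Each LUT entry above is produced by dequantising the input value with (ifm_scale, zp_in),
# applying fn, requantising with (ofm_scale, zp_out) using round-away-from-zero, and clamping
# to the quantized range of the data type.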


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")
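# For table entries below the input zero point (negative real values) the alpha scaling is applied,
# otherwise the identity scaling, so the LUT implements LeakyRelu per quantized input value.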


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
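# Illustrative example: for a 7x7 kernel with stride 2, max_size is 3, so a leading pad of 0, 2 or 3
# is accepted, while a pad of 1 would make hardware padding start on the wrong IFM row/column.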


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                np.uint8,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
    return op


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op
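# Note: the border tensors above are filled with the OFM zero point, i.e. the quantized
# representation of the real value 0.0 that PAD is expected to produce.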
1131
1132
1133def add_attrs_to_resizebilinear(op, arch, nng):
1134 if op.type == Op.ResizeBilinear and op.run_on_npu:
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001135 input_shape = op.ifm_shapes[0]
1136 upscaled_height = input_shape.height * 2
1137 upscaled_width = input_shape.width * 2
1138 out_shape = op.ofm_shapes[0]
1139 if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
1140 # this means the output is supposed to be a x2 upscale,
1141 # so we need to do SAME padding
1142 op.attrs["padding"] = Padding.SAME
1143 elif (
1144 op.attrs["align_corners"]
1145 and out_shape.height == (upscaled_height - 1)
1146 and out_shape.width == (upscaled_width - 1)
1147 ):
1148 # here we can just run the avg pool without padding and
1149 # produce a (M * 2 - 1, N * 2 - 1) sized output
1150 op.attrs["padding"] = Padding.VALID
1151 else:
1152 return op
Tim Hall3c5cfe92022-03-16 16:31:57 +00001153 op.ifm_resampling_mode = resampling_mode.NEAREST
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001154 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
1155 return op
1156
1157
1158def fixup_bias_tensors(op, arch, nng):
1159 if op.type.needs_bias() and op.bias is None:
1160 # Op has no bias, add bias tensor filled with zeros
1161 nr_biases = op.inputs[1].shape[-1]
1162 bias_values = [0] * nr_biases
1163 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001164 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1165
1166 return op
1167
1168
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001169def fixup_asymmetric_weights(op, arch, nng):
1170 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1171 if op.ifm.dtype == DataType.int8:
1172 if not np.all(op.weights.quantization.zero_point == 0):
1173 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1174 op.weights.quantization.zero_point *= 0
1175
1176 return op
1177
1178
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001179def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1180 if op.type == Op.Mean and op.run_on_npu:
1181 keep_dims = op.attrs.get("keep_dims", False)
1182 inp, axis = op.inputs
1183 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001184 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001185 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001186 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001187
1188 # Height and width axes have different index depending on dimensions
1189 if axis.shape == [] or axis.shape[0] == 1: # single axis
1190 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1191 if dims in (2, 3):
1192 if axis == 0:
1193 h, w = shape[axis], 1
1194 else:
1195 h, w = 1, shape[axis]
1196 else:
1197 if axis == 1:
1198 h, w = shape[axis], 1
1199 else:
1200 h, w = 1, shape[axis]
1201 else: # multiple axes
1202 axis = sorted(axis.values)
1203 h, w = [shape[i] for i in axis]
1204
1205 # Set necessary depthwise attributes
1206 op.attrs.update(
1207 {
1208 "padding": Padding.VALID,
1209 "stride_h": 1,
1210 "stride_w": 1,
1211 "strides": (1, 1, 1, 1),
1212 "depth_multiplier": 1,
1213 "channel_multiplier": 1,
1214 "dilation_h_factor": 1,
1215 "dilation_w_factor": 1,
1216 "dilation": (1, 1, 1, 1),
1217 }
1218 )
1219 # Change op type
1220 op.type = Op.DepthwiseConv2DBias
1221 # Set IFM/OFM shapes after changing op type
1222 op.set_ifm_ofm_shapes()
1223
1224 weight_scale, bias = 1, None
1225 ofmq, ifmq = op.ofm.quantization, inp.quantization
1226 # Set rounding mode, scaling and zero point based on which reference implementation to match
1227 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1228 if inp.dtype == DataType.uint8:
1229 # This attribute means a different scaling calculation is used in order to match reference
1230 op.low_precision_scaling = True
1231 weight_scale = h * w
1232 # Set zero points to 0 as they will be adjusted for with bias term
1233 foq = ofmq.clone()
1234 foq.zero_point = 0
1235 fiq = ifmq.clone()
1236 fiq.zero_point = 0
1237 op.forced_input_quantization = fiq
1238 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1239 # If the bias term is outside uint8 range, we need an Add op to apply it.
1240 if bias_term < 0 or bias_term > 255:
1241 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1242 # Bias term has higher bitness (i32) than input/output (u8).
1243 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1244 # the bias can only effectively assume values in the range [-255, 255].
1245 intermediate.dtype = DataType.int16
1246 intermediate.quantization.zero_point = 0
1247 add_op = Operation(Op.Add, op.name + "_bias")
1248 add_op.forced_output_quantization = foq
1249 add_op.add_input_tensor(intermediate)
1250 quant = QuantizationParameters()
1251 quant.zero_point = 0
1252 bias_term_tens = create_const_tensor(
Jonas Ohlssond8575072022-03-30 10:30:25 +02001253 op.name + "_bias",
1254 [1, 1, 1, 1],
1255 DataType.int16,
1256 [bias_term],
1257 np.int16,
1258 quantization=quant,
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001259 )
1260 add_op.add_input_tensor(bias_term_tens)
1261 add_op.set_output_tensor(op.ofm)
1262 add_op.set_ifm_ofm_shapes()
1263 add_op.activation = op.activation
1264 op.activation = None
1265 op.set_output_tensor(intermediate)
1266 op.set_ifm_ofm_shapes()
1267 # If not, we can just do it with the OFM zero point.
1268 else:
1269 foq.zero_point = bias_term
1270 op.forced_output_quantization = foq
1271 else:
1272 assert inp.dtype == DataType.int8
1273 # Use a depthwise to calculate the sum,
1274 # followed by a multiplication with 1/N to get the MEAN
1275 weight_scale = 1
1276 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1277 intermediate.dtype = DataType.int16
1278 mul_op = Operation(Op.Mul, op.name + "_mul")
1279 mul_op.add_input_tensor(intermediate)
1280 # Create scalar containing 1/N
1281 quant = QuantizationParameters()
1282 quant.zero_point = 0
1283 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1284 # while rounding mode NATURAL would round this to -1.
1285 # This can only occur if N is even, and can be emulated by
1286 # multiplying with a number that is slightly smaller than 1/N.
1287 # It must be so small that other roundings are not affected;
1288 # the calculated value is based on worst case,
1289 # which is sum 256 * N (the maximum sum that can occur with int8)
1290 n = int(h * w)
1291 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1292 quant.scale_f32 = 1 / (n - eps)
1293 scalar = create_const_tensor(
1294 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1295 )
1296 mul_op.add_input_tensor(scalar)
1297 mul_op.set_output_tensor(op.ofm)
1298 mul_op.set_ifm_ofm_shapes()
1299 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1300 mul_op.activation = op.activation
1301 op.activation = None
1302 op.set_output_tensor(intermediate)
1303 op.set_ifm_ofm_shapes()
1304 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1305 # Here we can just use a simple AvgPool with truncating rounding,
1306 # as we're emulating simple integer division.
1307 op.rounding_mode = NpuRoundingMode.TRUNCATE
1308 op.type = Op.AvgPool
1309 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1310 else:
1311 op.rounding_mode = NpuRoundingMode.NATURAL
1312 weight_scale = 1 / (h * w)
1313 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1314 bias = -ifmq.zero_point * h * w
1315 fiq = ifmq.clone()
1316 fiq.zero_point = 0
1317 op.forced_input_quantization = fiq
1318
1319 # Change dimensions to 4
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001320 def extend_dims(dim, in_shape):
1321 if dim < 4:
1322 in_shape = [1] + in_shape
1323 if dim == 2:
1324 in_shape += [1]
1325 return in_shape
1326
1327 if dims < 4 or dims_ofm < 4:
1328 # Fix the ofm dimension when keep_dims is false
1329 # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
1330 if isinstance(axis, int) and dims_ofm + 1 == dims:
1331 ofm_shape.insert(axis, 1)
1332 elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
1333 for i in axis:
1334 ofm_shape.insert(i, 1)
1335 shape = extend_dims(dims, shape)
1336 dims_ofm = len(ofm_shape)
1337 ofm_shape = extend_dims(dims_ofm, ofm_shape)
1338 op.set_ifm_ofm_shapes()
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001339
Rickard Bolin7d7cb672021-12-07 09:09:14 +00001340 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
1341 if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001342 shape = [shape[0], 1, h * w, shape[3]]
1343 op.ifm_shapes[0] = Shape4D(shape)
1344 if h > 256 and op.type == Op.AvgPool:
1345 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0
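        # The unit weights below only sum the inputs; any 1/(h*w) factor is carried by
        # weight_quant.scale_f32 (or by the separate Mul in the int8 path) rather than
        # by the weight values themselves.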

        # Set weight shape to [H,W,C,B]
        weight_shape = [h, w, shape[3], shape[0]]

        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                value_dtype=np.uint8,
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add None bias tensor
        op.inputs.append(None)
        # Add bias tensor
        if bias:
            bias_shape = [shape[-1]]
            op.set_input_tensor(
                create_const_tensor(
                    "bias",
                    bias_shape,
                    inp.dtype,
                    np.ones(bias_shape) * bias,
                    value_dtype=np.int32,
                    quantization=None,
                ),
                2,
            )

    return op


def optimise_quantize(op: Operation, arch, nng):

    if op.type == Op.Quantize and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()
        input_values = ifm.values

        # Guard clause - input not const or no values to quantize
        if ifm.ops[0].type != Op.Const or input_values is None:
            return op

        # Singular val in numpy array, convert to indexable array
        if input_values.ndim == 0:
            input_values = np.array([input_values])

        # requantized int8 to int8
        if ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8:

            # scale needs to use double precision to match TFLite reference kernel
            effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
            effective_multiplier, effective_shift = quantise_scale(effective_scale)

            assert effective_shift >= 0
            assert -31 <= effective_shift <= 30
            round_val = 1 << (effective_shift - 1)
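            # The loop below computes, entirely in integer arithmetic,
            #   ofm = clamp((((ifm - ifm_zp) * effective_multiplier + round_val) >> effective_shift) + ofm_zp)
            # i.e. effective_scale is applied as effective_multiplier / 2**effective_shift,
            # with round_val providing round-to-nearest before the shift.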

            requantized_vals = []
            for val in input_values:
                input_val = val - ifm.quantization.zero_point

                output = input_val * effective_multiplier + round_val
                ofm_val = (output >> effective_shift) + ofm.quantization.zero_point

                clamped_ofm_values = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
                requantized_vals.append(clamped_ofm_values)

            ofm.values = np.array(requantized_vals)

        # Case: Float input - quantize to int
        elif np.issubdtype(input_values.dtype, np.floating):

            quantized_vals = []
            for val in input_values:

                # Derive quantized value
                quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
                quantized_vals.append(quant_val)

            # Pass the statically calculated quant val to output tensor
            ofm.values = np.array(quantized_vals)

        # Make quantize op const and disconnect from parent node

        # Remove reference of the current quant op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this quantize op to const
        op.type = Op.Const

    return op


def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
    """Static optimisation for SHAPE operator output value known at compile time"""

    # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant

    if op.type == Op.Shape and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()

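        # Only fold when the OFM holds one element per IFM dimension, i.e. it really is the full shape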
        if len(ifm.shape) != ofm.shape[0]:
            return op

        # Remove reference of the current shape op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this SHAPE op to const
        op.type = Op.Const

        # Add size calculation to shape output tensors
        ofm.values = np.array(ifm.shape)
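        # e.g. an IFM of shape [1, 224, 224, 3] turns this op into a constant holding [1, 224, 224, 3]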

    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch):
    # Compile time optimisations
    optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
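    # Note: these constant-folding passes run before the supported-operators check below, so an op
    # folded to Const here is seen by the later passes as a constant rather than as a Quantize/Shape op.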

    for optimisation in optimisation_list:
        for idx, sg in enumerate(nng.subgraphs):
            nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
                nng,
                sg,
                arch,
                [],
                [optimisation],
                rewrite_unsupported=False,
            )

    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            pre_process_list,
            rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [rewrite_split_ops],
            [],
            rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [fix_sg_input_output],
            rewrite_unsupported=False,
        )

    # Removal of memory only operators
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        optimise_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        fixup_resizebilinear,
        fixup_bias_tensors,
        fixup_asymmetric_weights,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            op_rewrite_list,
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead; this needs to be done after the optimisations above have run,
    # since remove_SplitSliceRead depends on the ifm/ofm_shapes those passes set up
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    return nng