1# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
18# to do the traversal of the graph.
19import math
20import uuid
21from typing import Tuple
22
23import numpy as np
24
25from . import fp_math
26from . import lut
27from . import rewrite_graph
28from . import scaling
29from .api import NpuRoundingMode
30from .data_type import DataType
31from .debug_database import DebugDatabase
32from .errors import UnsupportedFeatureError
33from .ethos_u55_regs.ethos_u55_regs import resampling_mode
34from .graph_optimiser_util import needed_total_padding
35from .graph_optimiser_util import set_ifm_ofm_op_shapes
36from .graph_optimiser_util import set_tensor_equivalence
37from .numeric_util import clamp_sigmoid
38from .numeric_util import full_shape
39from .numeric_util import round_away_zero
40from .operation import create_activation_function
41from .operation import NpuBlockType
42from .operation import Op
43from .operation import Operation
44from .operation import Padding
45from .operation_util import create_avgpool_nop
46from .operation_util import get_pad_values_from_input
47from .shape4d import Shape4D
48from .softmax import SoftMax
49from .tensor import check_quantized_tens_scaling_equal
50from .tensor import create_const_tensor
51from .tensor import create_equivalence_id
52from .tensor import QuantizationParameters
53from .tensor import Tensor
54from .tensor import TensorPurpose
55from .tflite_mapping import optype_to_builtintype
56
57passthrough_nodes = (Op.Identity,)
58
59
60def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
61 """Creates an average pool for the given concat op/input feature map"""
62 ofm = concat_op.ofm
63 avgpool_op = create_avgpool_nop(name)
64 avgpool_op.inputs = [ifm]
65 avgpool_op.outputs = [ofm]
66
67 avgpool_op.write_offset = write_offset
68 avgpool_op.write_shape = ifm_shape
69 ofm.ops.append(avgpool_op)
70 DebugDatabase.add_optimised(concat_op, avgpool_op)
71 avgpool_op.ifm_shapes.append(ifm_shape)
72 avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
73 avgpool_op.memory_function = Op.ConcatSliceWrite
74 return avgpool_op
75
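# Example (illustrative): a CONCATENATION of two 1x8x8x16 IFMs along the depth
# axis is lowered to two of these avgpool NOPs that write into the same
# 1x8x8x32 OFM, the first with write_offset (0, 0, 0, 0) and the second with
# write_offset (0, 0, 0, 16).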
76
77def remove_passthrough_tensor(tens, arch, nng):
78 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
79 assert len(tens.ops[0].inputs) == 1
80 tens = tens.ops[0].inputs[0]
81 return tens
82
83
84def rewrite_concat_ops(op, arch):
85 if not op.run_on_npu or not op.type.is_concat_op():
86 return
87
88 axis_4D = 0
89 ofm = op.ofm
90 ofm.ops = []
91 offset = 0
92
93 unfuse_activation_function(op)
94
95 if op.type == Op.Pack:
96 # Pack is also referred to as Stack
97 axis = int(op.attrs["axis"])
98 if axis < 0: # Convert to positive axis
99 axis = len(op.inputs[0].shape) + 1 + axis
100
101 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
102
103 axis_4D = axis + (4 - len(desired_shape))
104
105 for idx, inp in enumerate(op.inputs):
106 op.ifm_shapes[idx] = Shape4D(desired_shape)
107 op.type = Op.PackReshaped
108
109 inputs, axis = op.get_concat_inputs_axis()
110 for idx, inp in enumerate(inputs):
111 if op.type != Op.PackReshaped:
112 op.ifm_shapes[idx] = Shape4D(inp.shape)
113 if axis >= 0:
114 axis_4D = axis + (4 - len(inp.shape))
115 else:
116 axis_4D = axis
117 write_offset = [0, 0, 0, 0]
118 write_offset[axis_4D] = offset
119 concat_end = offset + op.ifm_shapes[idx][axis_4D]
120 create_avg_pool_for_concat(
121 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
122 )
123 offset = concat_end
124 assert ofm.shape[axis] == offset
125
126 return op
127
128
129def rewrite_split_ops(tens, arch, nng):
130
131 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
132 split_op = tens.ops[0]
133
134 # Not supported so leave it and run on CPU
135 if not split_op.run_on_npu:
136 return tens
137
138 inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
139
140 tens.ops = []
141 new_op = Operation(Op.SplitSliceRead, split_op.name)
142 new_op.inputs = [inp]
143 ofm_shape_idx = 0
144 read_shape = offset_end
145
146 # For Split the offset cannot be extracted from the tensor so it has to
147 # be calculated from the index of the output tensor
148 if axis is not None:
149 # Get the start and end of the split
150 offset_start = [0] * 4
151 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
152 for idx, out in enumerate(outputs):
153 if axis_4D_list is not None:
154 axis_4D = axis_4D_list[idx]
155 else:
156 split_op.ofm_shapes[idx] = Shape4D(out.shape)
157 if axis >= 0:
158 axis_4D = axis + (4 - len(out.shape))
159 else:
160 axis_4D = axis
161
162 if out == tens:
163 ofm_shape_idx = idx
164 read_shape = split_op.ofm_shapes[idx]
165 break
166
167 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
168
169 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
170 new_op.read_shapes[0] = read_shape
171 new_op.run_on_npu = True
172 new_op.set_output_tensor(tens)
173 new_op.ifm_shapes.append(Shape4D(inp.shape))
174 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
175 DebugDatabase.add_optimised(split_op, new_op)
176
177 return tens
178
179
180def remove_SplitSliceRead(op, arch):
181
182 if op.type == Op.SplitSliceRead:
183        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
184 if (
185 len(op.ofm.consumer_list) == 1
186 and op.ofm.consumer_list[0] is not None
187 and op.ofm.consumer_list[0].run_on_npu
188            and op.ofm.consumer_list[0].type not in (Op.Reshape, Op.Squeeze)
189            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
190 ):
191 # SplitSliceRead can be performed by tensor consumer
192 cons_op = op.ofm.consumer_list[0]
193 if cons_op.ifm == op.ofm:
194 cons_op.read_offsets[0] = op.read_offsets[0]
195 cons_op.read_shapes[0] = op.read_shapes[0]
196 cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
197 cons_op.ifm_shapes[0] = op.ifm_shapes[0]
198 elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
199 cons_op.read_offsets[1] = op.read_offsets[0]
200 cons_op.read_shapes[1] = op.read_shapes[0]
201 cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
202 cons_op.ifm_shapes[1] = op.ifm_shapes[0]
203
204 if "skirt" in cons_op.attrs:
205 assert cons_op.attrs["explicit_padding"] == cons_op.attrs["skirt"]
206 cons_op.attrs["skirt"] = None
207 cons_op.attrs["force_padding"] = True
208 op.ofm.consumer_list.remove(cons_op)
209 op.ofm.ops = []
210 op.ifm.consumer_list.remove(op)
211 else:
212 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
213 avgpool_op.add_input_tensor(op.ifm)
214 avgpool_op.outputs = [op.ofm]
215 op.ofm.ops.remove(op)
216 op.ofm.ops.append(avgpool_op)
217 avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
218 avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
219 avgpool_op.read_offsets[0] = op.read_offsets[0]
220 avgpool_op.read_shapes[0] = op.read_shapes[0]
221
222 op.ifm.consumer_list.remove(op)
223 DebugDatabase.add_optimised(op, avgpool_op)
224
225
226def insert_copy_op_after_tens(tens):
227 tens_cons_list_copy = tens.consumer_list.copy()
228
229    # Create an avg_pool nop op with ifm as input
230 copy_tens = tens.clone()
231 copy_op = create_avgpool_nop(tens.name + "_avgpool")
232 copy_op.add_input_tensor(tens)
233 copy_op.set_output_tensor(copy_tens)
234 copy_op.set_ifm_ofm_shapes()
235 copy_op.run_on_npu = True
236
237 # Set copy_ifm consumers
238 for tens_cons in tens_cons_list_copy:
239 if tens_cons is not None:
240 for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
241 if cons_inp == tens:
242 tens_cons.set_input_tensor(copy_tens, ifm_idx)
243
244 DebugDatabase.add_optimised(tens.ops[0], copy_op)
245
246
247def fix_sg_input_output(op, arch, nng):
248    if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze):
249        return op
250
251    # For the Reshape/Squeeze operators we want to remove, tensors are removed.
252    # But in order to do this, they cannot be outputs of the sg;
253    # this needs to be fixed prior to the removal.
254    # The solution is to add an avgpool NOP to maintain the original tensor.
255    # This is also valid when the reshape ifm/ofm is produced or
256    # consumed by the CPU
257
258 # Check if operator ifm/ofm are sg ifm/ofm
259 ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
260 ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
261 ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
262    # Check if ifm/ofm is produced or consumed by the CPU
263    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
264 ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
265
266 if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
267        # Both ifm and ofm need to persist, but only the ifm needs a copy, in order to remove the Reshape/Squeeze
268        insert_copy_op_after_tens(op.ifm)
269
270 return op
271
272
273def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
274 """
275 Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
276 that provides equivalent results.
277 """
278 total_padding = needed_total_padding(input_size, stride, filter_size)
279 # The top/left padding can be taken as is from the PAD
280 output_pad_before = pad_before
281 # The bottom/right padding might need downward adjustment depending on stride/input size
282 output_pad_after = pad_after
283 while output_pad_after > 0 and output_pad_after % stride != (total_padding - pad_before) % stride:
284 output_pad_after -= 1
285 return output_pad_before, output_pad_after
286
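# Example (worked by hand, assuming the usual SAME-padding formula in
# needed_total_padding): calc_explicit_padding(input_size=10, stride=2,
# filter_size=3, pad_before=1, pad_after=1) returns (1, 0); the bottom/right
# padding is trimmed from 1 to 0 because with stride 2 the kernel never reads
# the last padded row/column, so the result is unchanged.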
287
288def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
289 k_w, k_h = kernel.dilated_wh()
290 s_x, s_y = kernel.stride
291 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
292 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
293 if padding_type == Padding.SAME:
294 left_pad = (xpad + 0) // 2
295 right_pad = (xpad + 1) // 2
296 top_pad = (ypad + 0) // 2
297 bottom_pad = (ypad + 1) // 2
298 elif padding_type == Padding.VALID:
299 left_pad = 0
300 right_pad = 0
301 top_pad = 0
302 bottom_pad = 0
303 elif padding_type == Padding.EXPLICIT:
304 # Padding is specified in a PAD operator which has been bypassed.
305 top, left, bottom, right = explicit_padding
306 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
307 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
308 else:
309        raise UnsupportedFeatureError("Unknown padding")
310 padding = (top_pad, left_pad, bottom_pad, right_pad)
311 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
312 return padding, skirt
313
314
315def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
316 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
317 if padding_type == Padding.SAME:
318 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
319 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
320 right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
321 bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
322 left_pad = max(kernel_width - 1 - right_pad, 0)
323 top_pad = max(kernel_height - 1 - bottom_pad, 0)
324 elif padding_type == Padding.VALID:
325 right_pad = max(kernel_width - 2, 0)
326 bottom_pad = max(kernel_height - 2, 0)
327 left_pad = kernel_width - 1
328 top_pad = kernel_height - 1
329 else:
330        raise UnsupportedFeatureError("Unknown padding")
331 padding = (top_pad, left_pad, bottom_pad, right_pad)
332 skirt = padding
333 return padding, skirt
334
335
336def fixup_conv2d_backprop(op, arch, nng):
337 if op.type == Op.Conv2DBackpropInput:
338 # flip the inputs
339 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
340 op.type = Op.Conv2DBackpropInputSwitchedBias
341 op.ifm.resampling_mode = resampling_mode.TRANSPOSE
342
343 # Update strides
344 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
345
346 return op
347
348
349# Convert the op to an elementwise add
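# Upscaling a 1x1 IFM with bilinear interpolation just broadcasts the single
# input value over the whole OFM, so it can be expressed as an elementwise Add
# of an all-zero tensor (shaped like the OFM) and the original 1x1 IFM.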
350def convert_resizebilinear_1x1_to_add(op):
351 op.type = Op.Add
352 op.name = op.name + "_add"
353 op.attrs["resizebilinear"] = True
354 # Create an input tensor filled with zeros
355 shape = op.ofm_shapes[0].as_list()
356 tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
357    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
358    tens.quantization = QuantizationParameters(0.0, 255.0)
359 tens.quantization.scale_f32 = 1.0
360 tens.quantization.zero_point = 0
361 tens.consumer_list = [op]
362 tens_op = op.inputs[1].ops[0]
363 tens_op.set_output_tensor(tens)
364 # Set the add inputs
365 op.inputs[1] = op.inputs[0]
366 op.inputs[0] = tens
367 op.set_ifm_ofm_shapes()
368
369 return op
370
371
372# Convert ResizeBilinear to a number of 2x2 pool ops
373def convert_resizebilinear_to_2x2_pool(op):
374 count = 0
375 pre_op = op
376 outputs = op.outputs
377
378 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
379 if op.attrs["align_corners"]:
380 shape_modifier = 1
381 op.attrs["padding"] = Padding.VALID
382 else:
383 shape_modifier = 0
384 op.attrs["padding"] = Padding.SAME
385 op.inputs[0].resampling_mode = resampling_mode.NEAREST
386
387 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
388 out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
389 if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
390 return op
391
392 while (upscaled_shape < out_shape).all():
393 if count == 0:
394 scaled_op = pre_op
395 else:
396 scaled_op = op.clone("_{}".format(count))
397 scaled_op.inputs[0] = pre_op.outputs[0]
398
399 upscaled_shape = upscaled_shape * 2 - shape_modifier
400
401 if (upscaled_shape == out_shape).all():
402 scaled_op.outputs = outputs
403 scaled_op.outputs[0].ops = [scaled_op]
404 else:
405 shape = op.ofm_shapes[0].as_list()
406 shape[1:3] = upscaled_shape
407 out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
408 out_tens.quantization = op.outputs[0].quantization.clone()
409 out_tens.quantization.quant_min = np.iinfo(np.int16).min
410 out_tens.quantization.quant_max = np.iinfo(np.int16).max
411 scaled_op.set_output_tensor(out_tens)
412 pre_op = scaled_op
413 count += 1
414
415 # Setup the scale value
416 if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
417 scaled_op.rescale = 128
418 elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
419 scaled_op.rescale = 1 / 128
420 else:
421 scaled_op.rescale = None
422 scaled_op.set_ifm_ofm_shapes()
423
424 return op
425
426
427def fixup_resizebilinear(op, arch, nng):
428 if op.type == Op.ResizeBilinear and op.run_on_npu:
429 if op.ifm_shapes[0] == op.ofm_shapes[0]:
430 # Bypass nop resizebilinear
431 op.inputs = op.inputs[:1]
432 op.type = Op.Identity
433 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
434 convert_resizebilinear_1x1_to_add(op)
435 else:
436 convert_resizebilinear_to_2x2_pool(op)
437
438 return op
439
440
441def convert_nop_split_to_identity(op, arch, nng):
442 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
443 # the list comprehension should return a list with a single tensor
444        # if it doesn't, remove_passthrough_tensor will fail appropriately
445 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
446 op.type = Op.Identity
447 return op
448
449
450def rewrite_fully_connected_input(op, arch, nng):
451 if op.type == Op.FullyConnected:
452 n_in_elems = op.weights.shape[-2]
453 elms = op.ifm.elements()
454 batch_size = elms // n_in_elems
455 assert batch_size * n_in_elems == elms
456
457 op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
458 return op
459
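# Example (illustrative): an IFM of shape [2, 128] feeding weights of shape
# [128, 10] has 256 elements and n_in_elems = 128, so the IFM shape is
# rewritten to the 4D form [2, 1, 1, 128] with batch size 2.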
460
461def convert_batched_fc_shape(op, arch, nng):
462 if op.type == Op.FullyConnected:
463 # Check if the first dimension indicates batching
464 if op.ifm_shapes[0].batch > 1:
465 batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
466 n = op.ifm_shapes[0].batch
467 h, w = batching_split.get(n, (1, n))
468 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
469
470 # Reshape Weights to be 4D. IO becomes HWIO
471 weight_tensor = op.inputs[1]
472            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
473 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
474
475 n = op.ofm_shapes[0].batch
476 h, w = batching_split.get(n, (1, n))
477 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
478 return op
479
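# Example (illustrative): a FullyConnected op with ifm_shapes[0] = [8, 1, 1, 40]
# and ofm_shapes[0] = [8, 1, 1, 10] is rewritten to ifm [1, 2, 4, 40] and
# ofm [1, 2, 4, 10] (batching_split maps batch 8 to a 2x4 grid), and the 2D
# IO weights are expanded to 4D HWIO with H = W = 1.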
480
481def unfuse_activation_function(op):
482 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
483 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
484 op.activation = None
485 out_tens = op.outputs[0]
486 intermediate_tens = out_tens.clone("_act_intermediate")
487 act_op.set_output_tensor(out_tens)
488 act_op.add_input_tensor(intermediate_tens)
489 op.set_output_tensor(intermediate_tens)
490 act_op.set_ifm_ofm_shapes()
491
492
493def rewrite_stridedslice_output(op, arch, nng):
494 if not op.run_on_npu or op.type != Op.StridedSlice:
495 return op
496
497 new_axis_mask = op.attrs["new_axis_mask"]
498 shrink_axis_mask = op.attrs["shrink_axis_mask"]
499
500 if shrink_axis_mask == 0 and new_axis_mask == 0:
501 return op
502
503 axis_4D = [0] * len(op.outputs)
504 for idx, out_tens in enumerate(op.outputs):
505 output_shape = list(out_tens.shape)
506
507 if shrink_axis_mask != 0:
508 n = 0
509 axis = 0
510 while shrink_axis_mask:
511 prev_mask = shrink_axis_mask
512 n += 1
513 shrink_axis_mask &= shrink_axis_mask - 1
514 axis = int(math.log2(prev_mask - shrink_axis_mask))
515 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
516
517 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
518 op.attrs["shrink_axis_mask"] = 0
519 if axis >= 0:
520 axis_4D[idx] = axis + (4 - len(output_shape))
521 else:
522 axis_4D[idx] = axis
523 op.ofm_shapes[idx] = Shape4D(output_shape)
524
525 elif new_axis_mask != 0:
526 n = 0
527 axis = 0
528 while new_axis_mask:
529 prev_mask = new_axis_mask
530 n += 1
531 new_axis_mask &= new_axis_mask - 1
532 axis = int(math.log2(prev_mask - new_axis_mask))
533 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
534 new_axis_mask >>= 1
535
536 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
537 op.attrs["new_axis_mask"] = 0
538 if axis >= 0:
539 axis_4D[idx] = axis + (4 - len(output_shape))
540 else:
541 axis_4D[idx] = axis
542 op.ofm_shapes[idx] = Shape4D(output_shape)
543
544 op.attrs["split_axis_4D"] = axis_4D
545 return op
546
547
548def rewrite_unpack_output(op, arch, nng):
549 tens = op.outputs[0]
550 if op.run_on_npu and op.type == Op.Unpack:
551 # Unpack is also referred to as Unstack
552 axis = int(op.attrs["axis"])
553 if axis < 0: # Convert to positive axis
554 axis = len(op.inputs[0].shape) + 1 + axis
555 op.type = Op.UnpackReshaped
556 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
557
558 axis_4D = axis + (4 - len(desired_output_shape))
559 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
560
561 for idx, out_tens in enumerate(op.outputs):
562 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
563 return op
564
565
566def add_padding_fields(op, arch, nng):
567 if op.run_on_npu:
568 if "padding" in op.attrs:
569 input_shape = op.ifm_shapes[0]
570 output_shape = op.ofm_shapes[0]
571 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
572 kernel_size = op.inputs[1].shape[:2]
573 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
574 kernel_size = op.attrs["ksize"][1:3]
575 else:
576 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
577
578 if op.type == Op.Conv2DBackpropInputSwitchedBias:
579 upscaling_factor = output_shape.height // input_shape.height
580 padding, skirt = calc_upscaled_padding_and_skirt(
581 op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
582 )
583 else:
584 padding, skirt = calc_padding_and_skirt(
585 op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
586 )
587
588 op.attrs["explicit_padding"] = padding
589 op.attrs["skirt"] = skirt
590
591 return op
592
593
594def convert_depthwise_to_conv(op, arch, nng):
595 # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
596    # the ofm depth equals the depth multiplier.
597 # If those conditions are true, then we can perform a simple
598 # switch of the operator type (and weight order)
599
600 if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
601 ifm_shape = op.ifm_shapes[0]
602 weight_tensor = op.inputs[1]
603 ofm_shape = op.ofm_shapes[0]
604 if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
605 # Change op type to Conv2d
606 op.type = Op.Conv2DBias
607 del op.attrs["channel_multiplier"]
608 del op.attrs["depth_multiplier"]
609
610            weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
611 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
612        else:
613 raise UnsupportedFeatureError(
614 f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
615 f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
616 )
617 DebugDatabase.add_optimised(op, op)
618 return op
619
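# Example (illustrative): a DEPTHWISE_CONV_2D with ifm depth 1 and
# depth_multiplier 8 produces 8 output channels from a single input channel,
# which is exactly what a regular convolution does, so only the op type and
# the order of the last two weight axes need to change.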
620
621def reorder_depthwise_weights(op, arch, nng):
622 if op.type.is_depthwise_conv2d_op():
623 weight_tensor = op.inputs[1]
624        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
625 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
626        weight_tensor.weight_transpose_depthwise = True
627
628 return op
629
630
631def optimise_strided_conv(op, arch, nng):
632 stride_x, stride_y = op.get_kernel_stride()
633 ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
634
635 if (
636 op.type == Op.Conv2DBias
637 and op.op_index == 0
638 and stride_x == 2
639 and op.ifm_shapes[0].depth <= 4
640 and op.ifm_shapes[0].width % 2 == 0
641 and weight_tensor is not None
642 and weight_tensor.shape[1] >= 2
643 ):
644 ifm_shape = op.ifm_shapes[0]
645 # IFM
646 op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
647
648 # Weights
649 weight_shape = weight_tensor.shape
650 if weight_shape[1] % 2 != 0:
651 weight_shape[1] = weight_shape[1] + 1
652 padded_array = np.zeros(weight_shape)
653 for i in range(weight_shape[0]):
654 padded_array[i] = np.vstack(
655 [
656                        weight_tensor.values[i],
657                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
658 ]
659 )
660            weight_tensor.values = padded_array
661        weight_shape[1] //= 2
662 weight_shape[2] *= 2
663        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
664        weight_tensor.set_all_shapes(weight_shape)
665 # If multiple copies of the weights are used, we could avoid
666 # them having the same address by changing the value_id
667 weight_tensor.value_id = uuid.uuid4()
668
669 # Strides
670 stride_x = 1
671 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
672
673 return op
674
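# Example (illustrative, assuming HWIO weight layout): a first-layer Conv2DBias
# with stride_x 2, a 3x3 kernel and a 1x128x128x3 IFM is rewritten so that the
# IFM is read as 1x128x64x6; the kernel is zero-point padded to width 4 and
# reshaped to 3x2x6xO (output channels unchanged), which lets the convolution
# run with stride_x 1.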
675
676def convert_conv_to_fc(op, arch, nng):
677 # Conv 1x1 can be equivalent to Fully Connected.
678    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
679    # caching/double buffering for the weights.
680    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
681 if op.type == Op.Conv2DBias:
682 h = op.ifm_shapes[0].height
683 w = op.ifm_shapes[0].width
684 kh, kw, _, _ = op.inputs[1].shape
685 if h == 1 and w == 1 and kh == 1 and kw == 1:
686 # Overwrite this op as a Fully Connected Op
687 op.name += "_fc"
688 op.type = Op.FullyConnected
689 op.attrs = {
690 "weights_format": 0,
691 }
692 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
693 weight_tensor = op.inputs[1]
694            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
695 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
696
697 DebugDatabase.add_optimised(op, op)
698 return op
699
700
701def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
702 if op.run_on_npu and op.type.is_relu_op():
703 ifm = op.inputs[0]
704 ofm = op.outputs[0]
705 # Relu with differing IFM and OFM scaling cannot be fused with another primary op
706 # and requires its own to be inserted
707 if not check_quantized_tens_scaling_equal(ifm, ofm):
708 # Override this op with its own primary op (avgpool)
709 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
710 # And fuse the original activation function to it
711 relu_fused_op.activation = create_activation_function(op.type)
712 # Tidy up and assign the ifm and ofm to the new op
713 ifm.consumer_list.remove(op)
714
715 relu_fused_op.add_input_tensor(ifm)
716 relu_fused_op.set_output_tensor(ofm)
717 relu_fused_op.set_ifm_ofm_shapes()
718 op = relu_fused_op
719 return op
720
721
722def fixup_elementwise_with_scalars(op, arch, nng):
723 if op.type.is_binary_elementwise_op():
724 ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
725 if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
726 diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
727 if diff > 0:
728 ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
729 elif diff < 0:
730 ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
731        elif ifm_tensor.shape == [] and ifm_tensor.values is None:
732            # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
733 ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
734 ifm_tensor.storage_shape = ifm_tensor.shape
735        elif ifm2_tensor.shape == [] and ifm2_tensor.values is None:
736            # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
737 ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
738 ifm2_tensor.storage_shape = ifm2_tensor.shape
739 return op
740
741
742def convert_softmax(op, arch, nng):
743 if op.type == Op.Softmax and op.run_on_npu:
744 softmax = SoftMax(op)
745 op = softmax.get_graph()
746 return op
747
748
749def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
750 r"""Whenever there is a subgraph with this topology:
751
752 Input X For X = -1 or X > 0
753 | \ / This subgraph can be replaced with either
754 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
755 | /
756 Max
757 """
758
759 if op.type == Op.Maximum:
760 # finds the Mul input(s) to the Max
761 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
762 if len(muls) == 1:
763 mul = muls[0].ops[0]
764 elif len(muls) == 2:
765 # In the case both inputs are Muls, find the one with the same input as the Max
766 mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
767 else:
768 # No Mul inputs
769 return op
770
771 # make sure the Mul doesn't have any other consumers
772 mul_ofm = mul.outputs[0]
773 if len(mul_ofm.consumers()) != 1:
774 return op
775 # make sure the Mul doesn't have a fused activation function
776 if mul.activation:
777 return op
778 ifm, ofm = op.get_ifm_ofm()
779 if ifm is None or ofm is None:
780 return op
781
782 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
783 return op
784 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
785 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
786 return op
787
788 # finds the branched input that goes to both the Max and the Mul
789 shared = set(op.inputs) & set(mul.inputs)
790 if len(shared) == 1:
791 shared_in = shared.pop()
792 # find the constant scalar input to the Mul
793 const_tens = (set(mul.inputs) - {shared_in}).pop()
794 # check that it is a scalar
795 if const_tens.shape != []:
796 return op
797 const = const_tens.ops[0]
798 # check that it is a constant
799 if const.type != Op.Const:
800 return op
801 # Remove the Mul from the shared input's consumers
802 shared_in.consumer_list.remove(mul)
803 else:
804 return op
805
806 val = const.outputs[0].values
807 if val >= 0:
808 new_op = Op.LeakyRelu
809 op.attrs["alpha"] = val
810 # to produce bit exact results, the alpha is not enough;
811            # save additional scaling info in attr "alpha_scaling", to be used as input
812 # to the LUT construction
813            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
814            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
815 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
816 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
817 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
818 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
819 elif val == -1:
820 new_op = Op.Abs
821 else:
822 return op
823
824 op.type = new_op
825 op.name = op.name.replace("Maximum", new_op.name)
826 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
827 op.inputs = [shared_in]
828 op.set_ifm_ofm_shapes()
829
830 # Record optimisation in debug database
831 DebugDatabase.add_optimised(op, op)
832
833 return op
834
835
836def convert_hardswish_to_lut(op, arch, nng):
837 if op.type == Op.HardSwish:
838 ifm, ofm = op.get_ifm_ofm()
839 # Generate the LUT
840 ifm_scale = np.double(ifm.quantization.scale_f32)
841 ofm_scale = np.double(ofm.quantization.scale_f32)
842 zp_in = ifm.quantization.zero_point
843 zp_out = ofm.quantization.zero_point
844 ifm_scale_hires = (1 / 128) * ifm_scale
845 relu_multiplier = np.double(3 / 32768)
846 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
847 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
848 # Use 16bit scale
849 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
850 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
851
852 values = []
853 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
854 quantized_min = min(ix)
855 quantized_max = max(ix)
856 for x in ix:
857 input_value = x - zp_in
858 input_value_hires = input_value * 128
859 # Compute the input value on essentially the output scale, not shifted yet
860 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
861 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
862 relu_value = np.int16(input_value_hires)
863 if relu_shift < 31:
864 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
865
866 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
867
868 if relu_shift < 31:
869 relu_value = fp_math.shift_left16(relu_value, 1)
870
871 if relu_shift > 31:
872 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
873
874 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
875 # Now convert that to a 16bit fixedpoint value in [0, 1]
876 relu_value = (relu_value + (1 << 15)) >> 1
877 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
878 shift = 31 - out_shift
879 shift = -shift if shift < 0 else 0
880 # Finally apply the output shift
881 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
882 lut_result = min(quantized_max, max(quantized_min, lut_result))
883 values.append(lut_result)
884 return convert_to_lut(op, values, "hardswish")
885 return op
886
887
888def convert_lrelu_to_mul_max(op, arch):
889 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
890 # (the opposite of convert_mul_max_to_abs_or_lrelu)
891 ifm, ofm = op.get_ifm_ofm()
892 if ifm is None or ofm is None:
893 return op
894
895 # Add multiplication with alpha
896 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
897 mul_alpha.add_input_tensor(ifm)
898 # Create const tensor containing alpha as scalar
899 alpha = op.attrs["alpha"]
900 quantization = ifm.quantization.clone()
901 quantization.min = 0
902 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
903 quantization.zero_point = 0
904 if np.isinf(1 / np.float32(alpha)):
905 # Handling of alpha near zero
906 quantization.scale_f32 = 1
907 scalar = 0
908 else:
909 quantization.scale_f32 = alpha
910 scalar = alpha
911 alpha_tens = create_const_tensor(
912 op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
913 )
914    alpha_tens.values = np.array([1])
915    mul_alpha.add_input_tensor(alpha_tens)
916 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
917 mul_alpha.set_output_tensor(fm_alpha)
918 mul_alpha.set_ifm_ofm_shapes()
919 DebugDatabase.add_optimised(op, mul_alpha)
920
921 if check_quantized_tens_scaling_equal(ifm, ofm):
922 # No identity multiplication is needed
923 fm_id = ifm
924 else:
925 # Add multiplication with identity
926 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
927 mul_identity.add_input_tensor(ifm)
928 # Create const tensor containing identity as scalar
929 quantization = ifm.quantization.clone()
930 quantization.min = 0
931 quantization.max = quantization.quant_max - quantization.quant_min
932 quantization.scale_f32 = 1
933 quantization.zero_point = 0
934 identity_tens = create_const_tensor(
935 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
936 )
937 mul_identity.add_input_tensor(identity_tens)
938 # Make sure that fm_id is allocated to a different address than fm_alpha
939 fm_id = ofm.clone(op.name + "_id", set_unique=True)
940 mul_identity.set_output_tensor(fm_id)
941 mul_identity.set_ifm_ofm_shapes()
942 DebugDatabase.add_optimised(op, mul_identity)
943
944 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
945 op.type = Op.Maximum
946 op.name = op.name.replace("LeakyRelu", "Maximum")
947 op.inputs = []
948 ifm.consumer_list.remove(op)
949 op.add_input_tensor(fm_alpha)
950 op.add_input_tensor(fm_id)
951 op.set_ifm_ofm_shapes()
952
953 DebugDatabase.add_optimised(op, op)
954 return op
955
956
957def convert_to_lut(op, lut_values, lut_name):
958    # Rewrite the operation as an Add with scalar 0 + LUT activation
959 ifm = op.inputs[0]
960 if ifm is None:
961 return op
962 assert ifm.dtype.size_in_bytes() == 1
963 op.type = Op.Add
964 op.name = op.name + "_lut_" + lut_name
965 # Mark as no-op to enable potential fusing optimizations
966 op.attrs["is_nop"] = True
967 # Create an input tensor containing scalar zero
968 quantization = QuantizationParameters(0.0, 255.0)
969 quantization.scale_f32 = ifm.quantization.scale_f32
970 quantization.zero_point = 0
971 tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
972 op.add_input_tensor(tens)
973 op.ifm_shapes.append(Shape4D(tens.shape))
974
975 # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
976 # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
977 # should be the same as the IFM
978 op.forced_output_quantization = ifm.quantization
979 lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
980 op.set_activation_lut(lut_tensor)
981 op.set_ifm_ofm_shapes()
982 return op
983
984
985def convert_to_lut8(op, fn, fn_name):
986 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
987 # fn is a function(real) -> real
988 ifm, ofm = op.get_ifm_ofm()
989 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
990 return op
991 # Generate the LUT
992 ifm_scale = np.double(ifm.quantization.scale_f32)
993 ofm_scale = np.double(ofm.quantization.scale_f32)
994 zp_in = ifm.quantization.zero_point
995 zp_out = ofm.quantization.zero_point
996 values = []
997 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
998 quantized_min = min(ix)
999 quantized_max = max(ix)
1000 for x in ix:
1001 x_real = ifm_scale * (x - zp_in)
1002 y_real = fn(x_real)
1003 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1004 lut_result = min(quantized_max, max(quantized_min, lut_result))
1005 values.append(lut_result)
1006 return convert_to_lut(op, values, fn_name)
1007
1008
1009def convert_lrelu_to_lut(op, arch):
1010 ifm, ofm = op.get_ifm_ofm()
1011 # Generate the LUT
1012 alpha = op.attrs["alpha"]
1013 ifm_scale = np.double(ifm.quantization.scale_f32)
1014 ofm_scale = np.double(ofm.quantization.scale_f32)
1015 zp_in = ifm.quantization.zero_point
1016 zp_out = ofm.quantization.zero_point
1017 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1018 alpha_scalar = 1
1019 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1020 if "alpha_scaling" in op.attrs:
1021 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1022 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1023 values = []
1024 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1025 quantized_min = min(ix)
1026 quantized_max = max(ix)
1027 for x in ix:
1028 if x < zp_in:
1029 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1030 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1031 )
1032 else:
1033 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1034 lut_result = min(quantized_max, max(quantized_min, lut_result))
1035 values.append(lut_result)
1036 return convert_to_lut(op, values, "lrelu")
1037
1038
1039def convert_lrelu(op, arch, nng):
1040 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1041 if op.type != Op.LeakyRelu:
1042 return op
1043 ifm, ofm = op.get_ifm_ofm()
1044 if ifm is None or ofm is None:
1045 return op
1046 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1047 # use LUT for int8/uint8
1048 return convert_lrelu_to_lut(op, arch)
1049 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
1050 # use LeakyRelu unmodified for int16 with equal input/output scaling
1051 return op
1052 return convert_lrelu_to_mul_max(op, arch)
1053
1054
1055def convert_tanh_sigmoid_to_lut(op, arch, nng):
1056 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1057 if op.type == Op.Sigmoid:
1058 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1059 elif op.type == Op.Tanh:
1060 return convert_to_lut8(op, math.tanh, "tanh")
1061 return op
1062
1063
1064def remove_reshape_and_squeeze_ops(op, arch):
1065    if op.run_on_npu and op.type in (Op.Reshape, Op.Squeeze):
1066        ofm = op.ofm
1067 ifm = op.ifm
1068
1069 # Check if quantization is the same in the input and output for the reshape ops
1070 if not check_quantized_tens_scaling_equal(ifm, ofm):
1071 # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
1072 # In order to remove this reshape either quantization properties need to be moved to Operator,
1073            # or the reshape needs to be replaced with a NOP.
1074 return
1075
1076        # Check if ifm/ofm are network ifm/ofm
1077        ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
1078 ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
1079 ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
1080        # Check if ifm/ofm is produced or consumed by the CPU
1081        ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
1082 ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
1083
1084 # This case should be handled prior to this function
1085 assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
1086
1087 if ofm_is_sg_ofm or ofm_is_cpu_consumed:
1088 # Bypassed by replacing ifm with ofm
1089 ofm.ops = []
1090 for prev_op in ifm.ops:
1091 prev_op.outputs = [ofm]
1092 ofm.ops.append(prev_op)
1093
1094 # All ifm consumers need to use ofm as input
1095 for ifm_cons in ifm.consumer_list:
1096 for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
1097 if cons_ifm == ifm:
1098 ifm_cons.set_input_tensor(ofm, ifm_idx)
1099 else:
1100            # Bypassed by replacing ofm with ifm
1101            for cons in ofm.consumer_list:
1102 for ifm_idx, cons_ifm in enumerate(cons.inputs):
1103 if cons_ifm == ofm:
1104 cons.set_input_tensor(ifm, ifm_idx)
1105
1106
1107def fuse_activation_function_with_prev(op, arch, nng):
1108 # if op is a no-op: attempts to move the activation function to the preceding op
1109 if not op.attrs.get("is_nop", False) or op.activation is None:
1110 return op
1111 ifm, ofm = op.get_ifm_ofm()
1112 if ifm is None or ofm is None:
1113 return op
1114 # finds the input(s) to the operation
1115 prev_op = ifm.ops[0]
1116 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1117 fuse = (
1118 prev_op.run_on_npu
1119 and prev_op.type.npu_block_type != NpuBlockType.Default
1120 and len(ifm.ops) == 1
1121 and len(prev_op.outputs[0].consumers()) == 1
1122 and prev_op.activation is None
1123 )
1124 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1125 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1126 # LUT currently only works correctly for elementwise ops
1127 fuse = False
1128 if not fuse:
1129 return op
1130 # Move the fused activation function + corresponding info to prev_op
1131 prev_op.activation = op.activation
1132 prev_op.forced_output_quantization = op.forced_output_quantization
1133 if op.activation_lut is not None:
1134 prev_op.set_activation_lut(op.activation_lut)
1135 # Bypass op
1136 prev_op.set_output_tensor(ofm)
1137 DebugDatabase.add_optimised(op, prev_op)
1138 return op
1139
1140
1141def _leading_pad_ok(leading_pad, stride, kernel_size):
1142 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1143 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
1144 max_size = kernel_size // 2
1145 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1146
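# Example (illustrative): with kernel_size=7 and stride=2 (max_size=3), leading
# pads of 0, 2 and 3 are accepted, but a leading pad of 1 is rejected since the
# stride-2 kernel would then visit the wrong IFM rows/columns.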
1147
1148def replace_pad_by_hw_pad(op: Operation, arch, nng):
1149 """
1150 Tries to completely remove a PAD operator by using hardware padding.
1151 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1152 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1153 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1154 if both operations can be run on the NPU.
1155 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1156 """
1157 if (
1158 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
1159 and op.run_on_npu
1160 and op.attrs["padding"] == Padding.VALID
1161 ):
1162 pad_op = op.ifm.ops[0]
1163 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1164 return op
1165 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1166 return op
1167 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1168 k = op.kernel
1169 k_w, k_h = k.dilated_wh()
1170
1171 # Check if the PAD operator can be replaced by hardware padding
1172 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1173 # Too much padding, it would require hardware padding to actually insert zeros
1174 return op
1175 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1176 return op
1177
1178 if op.type.is_avgpool_op():
1179 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1180 for pad, k_size in (
1181 (left, k_w),
1182 (right, k_w),
1183 (top, k_h),
1184 (bottom, k_h),
1185 ):
1186 if pad not in (0, k_size // 2):
1187 return op
1188 # Average pool is converted to depthwise, because NPU average pool + same padding
1189 # has a special implementation that is different from PAD followed by average pool with
1190 # valid padding.
1191 k_w, k_h = op.kernel.width, op.kernel.height
1192 ifm = op.ifm
1193 # Remember other inputs
1194 other_inputs = op.inputs[1:]
1195 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1196 quantization = QuantizationParameters(0.0, 255.0)
1197 quantization.scale_f32 = 1.0 / (k_w * k_h)
1198 quantization.zero_point = 0
1199 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1200 weights = np.full(shape, 1)
1201
1202 weight_tens = create_const_tensor(
1203 op.name + "_weights",
1204 shape,
1205 op.ifm.dtype,
1206 weights,
1207 np.uint8,
1208 purpose=TensorPurpose.Weights,
1209 quantization=quantization,
1210 )
1211            weight_tens.values = weights
1212            op.type = Op.DepthwiseConv2DBias
1213 op.inputs = []
1214 op.add_input_tensor(ifm)
1215 op.add_input_tensor(weight_tens)
1216 # Add bias tensor, all biases set to 0
1217 op.inputs.append(None)
1218 fixup_bias_tensors(op, arch, nng)
1219 # Add other inputs
1220 op.inputs.extend(other_inputs)
1221 op.rounding_mode = NpuRoundingMode.NATURAL
1222
1223 # Bypass the PAD operator
1224 op.set_input_tensor(pad_op.ifm, 0)
1225 # Adjust the padding attributes of the convolution operator
1226 op.attrs["padding"] = Padding.EXPLICIT
1227 op.attrs["explicit_padding"] = (top, left, bottom, right)
1228 op.set_ifm_ofm_shapes()
1229 return op
1230
1231
1232def convert_pad(op: Operation, arch, nng):
1233 """
1234 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1235 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1236 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1237 """
1238 if op.type != Op.Pad or not op.run_on_npu:
1239 return op
1240 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1241
1242 ifm = op.ifm
1243 assert ifm is not None
1244 ifm_shape = Shape4D(ifm.shape)
1245 ofm = op.ofm
1246 assert ofm is not None
1247 ofm.ops = []
1248 ofm_shape = op.ofm_shapes[0]
1249
1250 # Average pool op that copies IFM to the right place inside the OFM
1251 shp0 = Shape4D(0, 0, 0, 0)
1252 shp_top = shp0.with_height(top)
1253 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1254 avgpool_op.activation = op.activation
1255 quant = ofm.quantization
1256 pad_value = quant.zero_point
1257 # Add operations that fill the borders of the OFM
1258 if top > 0:
1259 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1260 zero_tens = create_const_tensor(
1261 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1262 )
1263 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1264 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1265 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1266 if bottom > 0:
1267 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1268 zero_tens = create_const_tensor(
1269 op.name + "_bottom",
1270 shape.as_list(),
1271 ofm.dtype,
1272 shape.elements() * [pad_value],
1273 np.uint8,
1274 quantization=quant,
1275 )
1276 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1277 create_avg_pool_for_concat(
1278 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1279 )
1280 if left > 0:
1281 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1282 zero_tens = create_const_tensor(
1283 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1284 )
1285 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1286 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1287 if right > 0:
1288 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1289 zero_tens = create_const_tensor(
1290 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1291 )
1292 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1293 create_avg_pool_for_concat(
1294 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1295 )
1296
1297 op.type = Op.ConcatTFLite
1298 return avgpool_op
1299
1300
1301def add_attrs_to_resizebilinear(op, arch, nng):
1302 if op.type == Op.ResizeBilinear and op.run_on_npu:
1303 input_tensor = op.inputs[0]
1304 input_shape = op.ifm_shapes[0]
1305 upscaled_height = input_shape.height * 2
1306 upscaled_width = input_shape.width * 2
1307 out_shape = op.ofm_shapes[0]
1308 if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
1309 # this means the output is supposed to be a x2 upscale,
1310 # so we need to do SAME padding
1311 op.attrs["padding"] = Padding.SAME
1312 elif (
1313 op.attrs["align_corners"]
1314 and out_shape.height == (upscaled_height - 1)
1315 and out_shape.width == (upscaled_width - 1)
1316 ):
1317 # here we can just run the avg pool without padding and
1318 # produce a (M * 2 - 1, N * 2 - 1) sized output
1319 op.attrs["padding"] = Padding.VALID
1320 else:
1321 return op
1322 input_tensor.resampling_mode = resampling_mode.NEAREST
1323 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
1324 return op
1325
1326
1327def fixup_bias_tensors(op, arch, nng):
1328 if op.type.needs_bias() and op.bias is None:
1329 # Op has no bias, add bias tensor filled with zeros
1330 nr_biases = op.inputs[1].shape[-1]
1331 bias_values = [0] * nr_biases
1332 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
1333        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1334
1335 return op
1336
1337
1338def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1339 if op.type == Op.Mean and op.run_on_npu:
1340 keep_dims = op.attrs.get("keep_dims", False)
1341 inp, axis = op.inputs
1342 shape = inp.shape
1343 dims = len(shape)
1344
1345 # Height and width axes have different index depending on dimensions
1346 if axis.shape == [] or axis.shape[0] == 1: # single axis
1347 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1348 if dims in (2, 3):
1349 if axis == 0:
1350 h, w = shape[axis], 1
1351 else:
1352 h, w = 1, shape[axis]
1353 else:
1354 if axis == 1:
1355 h, w = shape[axis], 1
1356 else:
1357 h, w = 1, shape[axis]
1358 else: # multiple axes
1359 axis = sorted(axis.values)
1360 h, w = [shape[i] for i in axis]
1361
1362 # Set necessary depthwise attributes
1363 op.attrs.update(
1364 {
1365 "padding": Padding.VALID,
1366 "stride_h": 1,
1367 "stride_w": 1,
1368 "strides": (1, 1, 1, 1),
1369 "depth_multiplier": 1,
1370 "channel_multiplier": 1,
1371 "dilation_h_factor": 1,
1372 "dilation_w_factor": 1,
1373 "dilation": (1, 1, 1, 1),
1374 }
1375 )
1376 # Change op type
1377 op.type = Op.DepthwiseConv2DBias
1378 # Set IFM/OFM shapes after changing op type
1379 op.set_ifm_ofm_shapes()
1380
1381 weight_scale, bias = 1, None
1382 ofmq, ifmq = op.ofm.quantization, inp.quantization
1383 # Set rounding mode, scaling and zero point based on which reference implementation to match
1384 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1385 if inp.dtype == DataType.uint8:
1386 # This attribute means a different scaling calculation is used in order to match reference
1387 op.low_precision_scaling = True
1388 weight_scale = h * w
1389 # Set zero points to 0 as they will be adjusted for with bias term
1390 foq = ofmq.clone()
1391 foq.zero_point = 0
1392 fiq = ifmq.clone()
1393 fiq.zero_point = 0
1394 op.forced_input_quantization = fiq
1395 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1396 # If the bias term is outside uint8 range, we need an Add op to apply it.
1397 if bias_term < 0 or bias_term > 255:
1398 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1399 # Bias term has higher bitness (i32) than input/output (u8).
1400 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1401 # the bias can only effectively assume values in the range [-255, 255].
1402 intermediate.dtype = DataType.int16
1403 intermediate.quantization.zero_point = 0
1404 add_op = Operation(Op.Add, op.name + "_bias")
1405 add_op.forced_output_quantization = foq
1406 add_op.add_input_tensor(intermediate)
1407 quant = QuantizationParameters()
1408 quant.zero_point = 0
1409 bias_term_tens = create_const_tensor(
1410                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
1411                    )
1412 add_op.add_input_tensor(bias_term_tens)
1413 add_op.set_output_tensor(op.ofm)
1414 add_op.set_ifm_ofm_shapes()
1415 add_op.activation = op.activation
1416 op.activation = None
1417 op.set_output_tensor(intermediate)
1418 op.set_ifm_ofm_shapes()
1419 # If not, we can just do it with the OFM zero point.
1420 else:
1421 foq.zero_point = bias_term
1422 op.forced_output_quantization = foq
1423 else:
1424 assert inp.dtype == DataType.int8
1425 # Use a depthwise to calculate the sum,
1426 # followed by a multiplication with 1/N to get the MEAN
1427 weight_scale = 1
1428 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1429 intermediate.dtype = DataType.int16
1430 mul_op = Operation(Op.Mul, op.name + "_mul")
1431 mul_op.add_input_tensor(intermediate)
1432 # Create scalar containing 1/N
1433 quant = QuantizationParameters()
1434 quant.zero_point = 0
                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
                # while rounding mode NATURAL would round this to -1.
                # This can only occur if N is even, and is emulated by dividing by a value slightly
                # smaller than N (i.e. a scale slightly larger than 1/N), pushing an exact -x.5
                # result just past the tie so that it rounds downwards.
                # The adjustment must be so small that no other roundings are affected;
                # the calculated value is based on the worst case,
                # which is a sum of 256 * N (the maximum sum that can occur with int8)
                n = int(h * w)
                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
                quant.scale_f32 = 1 / (n - eps)
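                # Illustrative example: for a 4x4 mean, n = 16 and eps = 1 / 4352, so
                # scale_f32 ~ 0.0625009; a sum of -24 (exact mean -1.5) then scales to ~ -1.50002,
                # which NATURAL rounding takes to -2, matching the reference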
                scalar = create_const_tensor(
                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
                )
                mul_op.add_input_tensor(scalar)
                mul_op.set_output_tensor(op.ofm)
                mul_op.set_ifm_ofm_shapes()
                mul_op.rounding_mode = NpuRoundingMode.NATURAL
                mul_op.activation = op.activation
                op.activation = None
                op.set_output_tensor(intermediate)
                op.set_ifm_ofm_shapes()
        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
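            # Illustrative example: with an input zero point of 128 and a 4x4 mean, bias = -2048,
            # cancelling the zero-point contribution accumulated by the all-ones kernel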
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        if dims < 4:
            shape = [1] + shape
            if dims == 2:
                shape += [1]

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        if h > 64:
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
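            # Illustrative example: a 100x10 reduction area is processed as a 1x1000 kernel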
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        # Set weight shape to [H,W,C,B]
        weight_shape = shape[1:4] + [shape[0]]
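        # e.g. a [1, 4, 4, 8] input gives weight_shape [4, 4, 8, 1]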
        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                value_dtype=np.uint8,
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
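        # The all-ones kernel makes the depthwise conv compute a per-channel sum over the HxW area;
        # the mean scaling itself is carried by the weight/output quantization set up above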

        # Add None bias tensor
        op.inputs.append(None)
        # Add bias tensor
        if bias:
            bias_shape = [shape[-1]]
            op.set_input_tensor(
                create_const_tensor(
                    "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
                ),
                2,
            )

    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch):
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]
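    # (These run first so that later passes can rely on op.run_on_npu and the ifm/ofm shapes)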

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
        )

    # Removal of reshapes and squeeze
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshape_and_squeeze_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        optimise_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        fixup_elementwise_with_scalars,
        reorder_depthwise_weights,
        fixup_resizebilinear,
        fixup_bias_tensors,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
    ]
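    # (The functions in op_rewrite_list are applied to each operator during the traversal below)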

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after the optimisations above have been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    return nng