1# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
18# to do the traversal of the graph.
19import math
20import uuid
21
22import numpy as np
23
24from . import fp_math
25from . import lut
26from . import rewrite_graph
27from . import scaling
28from .api import NpuRoundingMode
29from .data_type import DataType
30from .debug_database import DebugDatabase
31from .errors import UnsupportedFeatureError
32from .ethos_u55_regs.ethos_u55_regs import resampling_mode
33from .graph_optimiser_util import calc_explicit_padding
34from .graph_optimiser_util import needed_total_padding
35from .graph_optimiser_util import set_ifm_ofm_op_shapes
36from .graph_optimiser_util import set_tensor_equivalence
37from .numeric_util import clamp_sigmoid
38from .numeric_util import full_shape
39from .numeric_util import round_away_zero
40from .operation import create_activation_function
41from .operation import NpuBlockType
42from .operation import Op
43from .operation import Operation
44from .operation import Padding
45from .operation_util import create_avgpool_nop
46from .operation_util import get_pad_values_from_input
47from .shape4d import Shape4D
48from .softmax import SoftMax
49from .tensor import check_quantized_tens_scaling_equal
50from .tensor import create_const_tensor
51from .tensor import create_equivalence_id
52from .tensor import QuantizationParameters
53from .tensor import Tensor
54from .tensor import TensorPurpose
55from .tflite_mapping import optype_to_builtintype
56
57passthrough_nodes = (Op.Identity,)
58
59
60def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
61 """Creates an average pool for the given concat op/input feature map"""
62 ofm = concat_op.ofm
63 avgpool_op = create_avgpool_nop(name)
64 avgpool_op.inputs = [ifm]
65 avgpool_op.outputs = [ofm]
66
67 avgpool_op.write_offset = write_offset
68 avgpool_op.write_shape = ifm_shape
69 ofm.ops.append(avgpool_op)
70 DebugDatabase.add_optimised(concat_op, avgpool_op)
71 avgpool_op.ifm_shapes.append(ifm_shape)
72 avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
73 avgpool_op.memory_function = Op.ConcatSliceWrite
74 return avgpool_op
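# Note: concatenation is lowered to one AvgPool NOP per input, each writing its IFM into the shared
# OFM at a write_offset along the concat axis. Illustrative example (assumed shapes): concatenating
# two [1, 8, 8, 16] tensors along depth gives write offsets [0, 0, 0, 0] and [0, 0, 0, 16] into a
# [1, 8, 8, 32] OFM.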
75
76
77def remove_passthrough_tensor(tens, arch, nng):
78 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
79 assert len(tens.ops[0].inputs) == 1
80 tens = tens.ops[0].inputs[0]
81 return tens
82
83
84def rewrite_concat_ops(op, arch):
85 if not op.run_on_npu or not op.type.is_concat_op():
86 return
87
88 axis_4D = 0
89 ofm = op.ofm
90 ofm.ops = []
91 offset = 0
92
93 unfuse_activation_function(op)
94
95 if op.type == Op.Pack:
96 # Pack is also referred to as Stack
97 axis = int(op.attrs["axis"])
98 if axis < 0: # Convert to positive axis
99 axis = len(op.inputs[0].shape) + 1 + axis
100
101 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
102
103 axis_4D = axis + (4 - len(desired_shape))
104
105 for idx, inp in enumerate(op.inputs):
106 op.ifm_shapes[idx] = Shape4D(desired_shape)
107 op.type = Op.PackReshaped
108
109 inputs, axis = op.get_concat_inputs_axis()
110 for idx, inp in enumerate(inputs):
111 if op.type != Op.PackReshaped:
112 op.ifm_shapes[idx] = Shape4D(inp.shape)
113 if axis >= 0:
114 axis_4D = axis + (4 - len(inp.shape))
115 else:
116 axis_4D = axis
117 write_offset = [0, 0, 0, 0]
118 write_offset[axis_4D] = offset
119 concat_end = offset + op.ifm_shapes[idx][axis_4D]
120 create_avg_pool_for_concat(
121 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
122 )
123 offset = concat_end
124 assert ofm.shape[axis] == offset
125
126 return op
127
128
129def rewrite_split_ops(tens, arch, nng):
130
131 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
132 split_op = tens.ops[0]
133
134 # Not supported so leave it and run on CPU
135 if not split_op.run_on_npu:
136 return tens
137
138 inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
139
140 tens.ops = []
141 new_op = Operation(Op.SplitSliceRead, split_op.name)
142 new_op.inputs = [inp]
143 ofm_shape_idx = 0
144 read_shape = offset_end
145
146 # For Split the offset cannot be extracted from the tensor so it has to
147 # be calculated from the index of the output tensor
148 if axis is not None:
149 # Get the start and end of the split
150 offset_start = [0] * 4
151 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
152 for idx, out in enumerate(outputs):
153 if axis_4D_list is not None:
154 axis_4D = axis_4D_list[idx]
155 else:
156 split_op.ofm_shapes[idx] = Shape4D(out.shape)
157 if axis >= 0:
158 axis_4D = axis + (4 - len(out.shape))
159 else:
160 axis_4D = axis
161
162 if out == tens:
163 ofm_shape_idx = idx
164 read_shape = split_op.ofm_shapes[idx]
165 break
166
167 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
168
169 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
170 new_op.read_shapes[0] = read_shape
171 new_op.run_on_npu = True
172 new_op.set_output_tensor(tens)
173 new_op.ifm_shapes.append(Shape4D(inp.shape))
174 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
175 DebugDatabase.add_optimised(split_op, new_op)
176
177 return tens
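# Note: each output of a split becomes a SplitSliceRead that reads a slice of the single IFM.
# Illustrative example (assumed shapes): splitting a [1, 8, 8, 16] tensor into two along depth gives
# read offsets [0, 0, 0, 0] and [0, 0, 0, 8], each with a read shape of (1, 8, 8, 8).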
178
179
180def remove_SplitSliceRead(op, arch):
181
182 if op.type == Op.SplitSliceRead:
183 # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
184 if (
185 len(op.ofm.consumer_list) == 1
186 and op.ofm.consumer_list[0] is not None
187 and op.ofm.consumer_list[0].run_on_npu
188 and op.ofm.consumer_list[0].type not in (Op.Reshape, Op.Squeeze)
189 and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
190 ):
191 # SplitSliceRead can be performed by tensor consumer
192 cons_op = op.ofm.consumer_list[0]
193 if cons_op.ifm == op.ofm:
194 cons_op.read_offsets[0] = op.read_offsets[0]
195 cons_op.read_shapes[0] = op.read_shapes[0]
196 cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
197 cons_op.ifm_shapes[0] = op.ifm_shapes[0]
198 elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
199 cons_op.read_offsets[1] = op.read_offsets[0]
200 cons_op.read_shapes[1] = op.read_shapes[0]
201 cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
202 cons_op.ifm_shapes[1] = op.ifm_shapes[0]
203
204 if "skirt" in cons_op.attrs:
205 assert cons_op.attrs["explicit_padding"] == cons_op.attrs["skirt"]
206 cons_op.attrs["skirt"] = None
207 cons_op.attrs["force_padding"] = True
208 op.ofm.consumer_list.remove(cons_op)
209 op.ofm.ops = []
210 op.ifm.consumer_list.remove(op)
211 else:
212 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
213 avgpool_op.add_input_tensor(op.ifm)
214 avgpool_op.outputs = [op.ofm]
215 op.ofm.ops.remove(op)
216 op.ofm.ops.append(avgpool_op)
217 avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
218 avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
219 avgpool_op.read_offsets[0] = op.read_offsets[0]
220 avgpool_op.read_shapes[0] = op.read_shapes[0]
221
222 op.ifm.consumer_list.remove(op)
223 DebugDatabase.add_optimised(op, avgpool_op)
224
225
226def insert_copy_op_after_tens(tens):
227 tens_cons_list_copy = tens.consumer_list.copy()
228
229 # Create an avg_pool NOP op with ifm as input
230 copy_tens = tens.clone()
231 copy_op = create_avgpool_nop(tens.name + "_avgpool")
232 copy_op.add_input_tensor(tens)
233 copy_op.set_output_tensor(copy_tens)
234 copy_op.set_ifm_ofm_shapes()
235 copy_op.run_on_npu = True
236
237 # Set copy_ifm consumers
238 for tens_cons in tens_cons_list_copy:
239 if tens_cons is not None:
240 for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
241 if cons_inp == tens:
242 tens_cons.set_input_tensor(copy_tens, ifm_idx)
243
244 DebugDatabase.add_optimised(tens.ops[0], copy_op)
245
246
247def fix_sg_input_output(op, arch, nng):
248 if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze):
249 return op
250
251 # For the Reshape/Squeeze operators we want to remove, tensors are removed.
252 # But in order to do this, they cannot be outputs of the sg,
253 # so this needs to be fixed prior to the removal.
254 # The solution is to add an avgpool NOP, to maintain the original tensor.
255 # This is also valid when the reshape ifm is produced by the CPU or the
256 # ofm is consumed by the CPU
257
258 # Check if operator ifm/ofm are sg ifm/ofm
259 ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
260 ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
261 ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
262 # Check if the ifm is produced by the CPU or the ofm is consumed by the CPU
263 ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
264 ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
265
266 if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
267 # Both ifm and ofm need to persist, but only the ifm needs a copy in order to remove the Reshape/Squeeze
268 insert_copy_op_after_tens(op.ifm)
269
270 return op
271
272
273def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
274 k_w, k_h = kernel.dilated_wh()
275 s_x, s_y = kernel.stride
276 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
277 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
278 if padding_type == Padding.SAME:
279 left_pad = (xpad + 0) // 2
280 right_pad = (xpad + 1) // 2
281 top_pad = (ypad + 0) // 2
282 bottom_pad = (ypad + 1) // 2
283 elif padding_type == Padding.VALID:
284 left_pad = 0
285 right_pad = 0
286 top_pad = 0
287 bottom_pad = 0
288 elif padding_type == Padding.EXPLICIT:
289 # Padding is specified in a PAD operator which has been bypassed.
290 top, left, bottom, right = explicit_padding
291 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
292 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
293 else:
294 raise UnsupportedFeatureError(f"Unknown padding {padding_type}")
295 padding = (top_pad, left_pad, bottom_pad, right_pad)
296 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
297 return padding, skirt
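# Worked example (illustrative values): SAME padding, 224x224 IFM, 3x3 kernel, stride 2 gives
# xpad = ypad = 1, so padding = (top, left, bottom, right) = (0, 0, 1, 1) and skirt = (0, 0, 1, 1);
# the extra pixel lands on the bottom/right, matching the TensorFlow SAME convention.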
298
299
300def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
301 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
302 if padding_type == Padding.SAME:
303 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
304 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
305 right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
306 bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
307 left_pad = max(kernel_width - 1 - right_pad, 0)
308 top_pad = max(kernel_height - 1 - bottom_pad, 0)
309 elif padding_type == Padding.VALID:
310 right_pad = max(kernel_width - 2, 0)
311 bottom_pad = max(kernel_height - 2, 0)
312 left_pad = kernel_width - 1
313 top_pad = kernel_height - 1
314 else:
315 raise UnsupportedFeatureError(f"Unknown padding {padding_type}")
316 padding = (top_pad, left_pad, bottom_pad, right_pad)
317 skirt = padding
318 return padding, skirt
319
320
321def fixup_conv2d_backprop(op, arch, nng):
322 if op.type == Op.Conv2DBackpropInput:
323 # flip the inputs
324 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
325 op.type = Op.Conv2DBackpropInputSwitchedBias
326 op.ifm.resampling_mode = resampling_mode.TRANSPOSE
327
328 # Update strides
329 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
330
331 return op
332
333
334# Convert the op to an elementwise add
335def convert_resizebilinear_1x1_to_add(op):
336 op.type = Op.Add
337 op.name = op.name + "_add"
338 op.attrs["resizebilinear"] = True
339 # Create an input tensor filled with zeros
340 shape = op.ofm_shapes[0].as_list()
341 tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
342 tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
343 tens.quantization = QuantizationParameters(0.0, 255.0)
344 tens.quantization.scale_f32 = 1.0
345 tens.quantization.zero_point = 0
346 tens.consumer_list = [op]
347 tens_op = op.inputs[1].ops[0]
348 tens_op.set_output_tensor(tens)
349 # Set the add inputs
350 op.inputs[1] = op.inputs[0]
351 op.inputs[0] = tens
352 op.set_ifm_ofm_shapes()
353
354 return op
355
356
357# Convert ResizeBilinear to a number of 2x2 pool ops
358def convert_resizebilinear_to_2x2_pool(op):
359 count = 0
360 pre_op = op
361 outputs = op.outputs
362
363 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
364 if op.attrs["align_corners"]:
365 shape_modifier = 1
366 op.attrs["padding"] = Padding.VALID
367 else:
368 shape_modifier = 0
369 op.attrs["padding"] = Padding.SAME
370 op.inputs[0].resampling_mode = resampling_mode.NEAREST
371
372 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
373 out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
374 if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
375 return op
376
377 while (upscaled_shape < out_shape).all():
378 if count == 0:
379 scaled_op = pre_op
380 else:
381 scaled_op = op.clone("_{}".format(count))
382 scaled_op.inputs[0] = pre_op.outputs[0]
383
384 upscaled_shape = upscaled_shape * 2 - shape_modifier
385
386 if (upscaled_shape == out_shape).all():
387 scaled_op.outputs = outputs
388 scaled_op.outputs[0].ops = [scaled_op]
389 else:
390 shape = op.ofm_shapes[0].as_list()
391 shape[1:3] = upscaled_shape
392 out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
393 out_tens.quantization = op.outputs[0].quantization.clone()
394 out_tens.quantization.quant_min = np.iinfo(np.int16).min
395 out_tens.quantization.quant_max = np.iinfo(np.int16).max
396 scaled_op.set_output_tensor(out_tens)
397 pre_op = scaled_op
398 count += 1
399
400 # Setup the scale value
401 if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
402 scaled_op.rescale = 128
403 elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
404 scaled_op.rescale = 1 / 128
405 else:
406 scaled_op.rescale = None
407 scaled_op.set_ifm_ofm_shapes()
408
409 return op
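# Note: the x2 upscaling itself is done by the hardware resampler (resampling_mode.NEAREST on the IFM);
# each 2x2 average pool with stride 1 then approximates one x2 stage of bilinear interpolation.
# Illustrative example (assumed sizes): a 4x4 -> 16x16 resize needs two chained stages (4 -> 8 -> 16),
# with the intermediate tensor held in int16.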
410
411
412def fixup_resizebilinear(op, arch, nng):
413 if op.type == Op.ResizeBilinear and op.run_on_npu:
414 if op.ifm_shapes[0] == op.ofm_shapes[0]:
415 # Bypass nop resizebilinear
416 op.inputs = op.inputs[:1]
417 op.type = Op.Identity
418 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
419 convert_resizebilinear_1x1_to_add(op)
420 else:
421 convert_resizebilinear_to_2x2_pool(op)
422
423 return op
424
425
426def convert_nop_split_to_identity(op, arch, nng):
427 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
428 # the list comprehension should return a list with a single tensor
429 # if it shouldn't, remove_passthrough_tensor will fail appropriately
430 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
431 op.type = Op.Identity
432 return op
433
434
435def rewrite_fully_connected_input(op, arch, nng):
436 if op.type == Op.FullyConnected:
437 n_in_elems = op.weights.shape[-2]
438 elms = op.ifm.elements()
439 batch_size = elms // n_in_elems
440 assert batch_size * n_in_elems == elms
441
442 op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
443 return op
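# Illustrative example (assumed sizes): an IFM with 1280 elements and n_in_elems = 640 gives
# batch_size = 2, so the IFM is presented to the NPU as Shape4D([2, 1, 1, 640]).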
444
445
446def convert_batched_fc_shape(op, arch, nng):
447 if op.type == Op.FullyConnected:
448 # Check if the first dimension indicates batching
449 if op.ifm_shapes[0].batch > 1:
450 batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
451 n = op.ifm_shapes[0].batch
452 h, w = batching_split.get(n, (1, n))
453 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
454
455 # Reshape Weights to be 4D. IO becomes HWIO
456 weight_tensor = op.inputs[1]
457 weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
458 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
459
460 n = op.ofm_shapes[0].batch
461 h, w = batching_split.get(n, (1, n))
462 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
463 return op
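# Note: the batch dimension is folded into a spatial H x W grid: 4 -> 2x2, 8 -> 2x4, 16 -> 4x4,
# anything else -> 1xN. Illustrative example (assumed shapes): an IFM of [8, 1, 1, 64] becomes
# [1, 2, 4, 64], and the 2D IO weights are expanded to HWIO with H = W = 1.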
464
465
466def unfuse_activation_function(op):
467 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
468 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
469 op.activation = None
470 out_tens = op.outputs[0]
471 intermediate_tens = out_tens.clone("_act_intermediate")
472 act_op.set_output_tensor(out_tens)
473 act_op.add_input_tensor(intermediate_tens)
474 op.set_output_tensor(intermediate_tens)
475 act_op.set_ifm_ofm_shapes()
476
477
478def rewrite_stridedslice_output(op, arch, nng):
479 if not op.run_on_npu or op.type != Op.StridedSlice:
480 return op
481
482 new_axis_mask = op.attrs["new_axis_mask"]
483 shrink_axis_mask = op.attrs["shrink_axis_mask"]
484
485 if shrink_axis_mask == 0 and new_axis_mask == 0:
486 return op
487
488 axis_4D = [0] * len(op.outputs)
489 for idx, out_tens in enumerate(op.outputs):
490 output_shape = list(out_tens.shape)
491
492 if shrink_axis_mask != 0:
493 n = 0
494 axis = 0
495 while shrink_axis_mask:
496 prev_mask = shrink_axis_mask
497 n += 1
498 shrink_axis_mask &= shrink_axis_mask - 1
499 axis = int(math.log2(prev_mask - shrink_axis_mask))
500 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
501
502 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
503 op.attrs["shrink_axis_mask"] = 0
504 if axis >= 0:
505 axis_4D[idx] = axis + (4 - len(output_shape))
506 else:
507 axis_4D[idx] = axis
508 op.ofm_shapes[idx] = Shape4D(output_shape)
509
510 elif new_axis_mask != 0:
511 n = 0
512 axis = 0
513 while new_axis_mask:
514 prev_mask = new_axis_mask
515 n += 1
516 new_axis_mask &= new_axis_mask - 1
517 axis = int(math.log2(prev_mask - new_axis_mask))
518 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
519 new_axis_mask >>= 1
520
521 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
522 op.attrs["new_axis_mask"] = 0
523 if axis >= 0:
524 axis_4D[idx] = axis + (4 - len(output_shape))
525 else:
526 axis_4D[idx] = axis
527 op.ofm_shapes[idx] = Shape4D(output_shape)
528
529 op.attrs["split_axis_4D"] = axis_4D
530 return op
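# Note: the loops above peel off one set mask bit per iteration; mask & (mask - 1) clears the lowest
# set bit and log2 of the difference recovers that bit's index, i.e. the affected axis.
# Illustrative example (assumed mask): shrink_axis_mask = 0b10 means axis 1 was removed from the
# output, so a size-1 dimension is re-inserted at index 1 when building the 4D ofm shape.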
531
532
533def rewrite_unpack_output(op, arch, nng):
534 tens = op.outputs[0]
535 if op.run_on_npu and op.type == Op.Unpack:
536 # Unpack is also referred to as Unstack
537 axis = int(op.attrs["axis"])
538 if axis < 0: # Convert to positive axis
539 axis = len(op.inputs[0].shape) + 1 + axis
540 op.type = Op.UnpackReshaped
541 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
542
543 axis_4D = axis + (4 - len(desired_output_shape))
544 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
545
546 for idx, out_tens in enumerate(op.outputs):
547 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
548 return op
549
550
551def add_padding_fields(op, arch, nng):
552 if op.run_on_npu:
553 if "padding" in op.attrs:
554 input_shape = op.ifm_shapes[0]
555 output_shape = op.ofm_shapes[0]
556 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
557 kernel_size = op.inputs[1].shape[:2]
558 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
559 kernel_size = op.attrs["ksize"][1:3]
560 else:
561 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
562
563 if op.type == Op.Conv2DBackpropInputSwitchedBias:
564 upscaling_factor = output_shape.height // input_shape.height
565 padding, skirt = calc_upscaled_padding_and_skirt(
566 op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
567 )
568 else:
569 padding, skirt = calc_padding_and_skirt(
570 op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
571 )
572
573 op.attrs["explicit_padding"] = padding
574 op.attrs["skirt"] = skirt
575
576 return op
577
578
579def convert_depthwise_to_conv(op, arch, nng):
580 # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
581 # the ofm depth equals the depth multiplier.
582 # If those conditions are true, then we can perform a simple
583 # switch of the operator type (and weight order)
584
585 if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
586 ifm_shape = op.ifm_shapes[0]
587 weight_tensor = op.inputs[1]
588 ofm_shape = op.ofm_shapes[0]
589 if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
590 # Change op type to Conv2d
591 op.type = Op.Conv2DBias
592 del op.attrs["channel_multiplier"]
593 del op.attrs["depth_multiplier"]
594
595 weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
596 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
597 else:
598 raise UnsupportedFeatureError(
599 f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
600 f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
601 )
602 DebugDatabase.add_optimised(op, op)
603 return op
604
605
606def reorder_depthwise_weights(op, arch, nng):
607 if op.type.is_depthwise_conv2d_op():
608 weight_tensor = op.inputs[1]
609 weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
610 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
611 weight_tensor.weight_transpose_depthwise = True
612
613 return op
614
615
616def optimise_strided_conv(op, arch, nng):
617 stride_x, stride_y = op.get_kernel_stride()
618 ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
619
620 if (
621 op.type == Op.Conv2DBias
622 and op.op_index == 0
623 and stride_x == 2
624 and op.ifm_shapes[0].depth <= 4
625 and op.ifm_shapes[0].width % 2 == 0
626 and weight_tensor is not None
627 and weight_tensor.shape[1] >= 2
628 ):
629 ifm_shape = op.ifm_shapes[0]
630 # IFM
631 op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
632
633 # Weights
634 weight_shape = weight_tensor.shape
635 if weight_shape[1] % 2 != 0:
636 weight_shape[1] = weight_shape[1] + 1
637 padded_array = np.zeros(weight_shape)
638 for i in range(weight_shape[0]):
639 padded_array[i] = np.vstack(
640 [
641 weight_tensor.values[i],
642 np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
643 ]
644 )
645 weight_tensor.values = padded_array
646 weight_shape[1] //= 2
647 weight_shape[2] *= 2
648 weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
649 weight_tensor.set_all_shapes(weight_shape)
650 # If multiple copies of the weights are used, we could avoid
651 # them having the same address by changing the value_id
652 weight_tensor.value_id = uuid.uuid4()
653
654 # Strides
655 stride_x = 1
656 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
657
658 return op
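# Note: the stride-2 convolution is rewritten as a stride-1 convolution on an IFM that is viewed as
# half as wide but twice as deep; the kernel is zero-point padded to an even width and reshaped the
# same way. Illustrative example (assumed shapes): an IFM of [1, 56, 56, 3] becomes [1, 56, 28, 6],
# and a 3x3x3xN kernel is padded to 3x4x3xN and then reshaped to 3x2x6xN.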
659
660
661def convert_conv_to_fc(op, arch, nng):
662 # Conv 1x1 can be equivalent to Fully Connected.
663 # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
664 # caching/double buffering for the weights.
665 # (Weights don't need to be reloaded for convs when IFM H and W are 1)
666 if op.type == Op.Conv2DBias:
667 h = op.ifm_shapes[0].height
668 w = op.ifm_shapes[0].width
669 kh, kw, _, _ = op.inputs[1].shape
670 if h == 1 and w == 1 and kh == 1 and kw == 1:
671 # Overwrite this op as a Fully Connected Op
672 op.name += "_fc"
673 op.type = Op.FullyConnected
674 op.attrs = {
675 "weights_format": 0,
676 }
677 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
678 weight_tensor = op.inputs[1]
679 weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
680 weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
681
682 DebugDatabase.add_optimised(op, op)
683 return op
684
685
686def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
687 if op.run_on_npu and op.type.is_relu_op():
688 ifm = op.inputs[0]
689 ofm = op.outputs[0]
690 # Relu with differing IFM and OFM scaling cannot be fused with another primary op
691 # and requires its own to be inserted
692 if not check_quantized_tens_scaling_equal(ifm, ofm):
693 # Override this op with its own primary op (avgpool)
694 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
695 # And fuse the original activation function to it
696 relu_fused_op.activation = create_activation_function(op.type)
697 # Tidy up and assign the ifm and ofm to the new op
698 ifm.consumer_list.remove(op)
699
700 relu_fused_op.add_input_tensor(ifm)
701 relu_fused_op.set_output_tensor(ofm)
702 relu_fused_op.set_ifm_ofm_shapes()
703 op = relu_fused_op
704 return op
705
706
707def fixup_elementwise_with_scalars(op, arch, nng):
708 if op.type.is_binary_elementwise_op():
709 ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
710 if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
711 diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
712 if diff > 0:
713 ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
714 elif diff < 0:
715 ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
716 elif ifm_tensor.shape == [] and ifm_tensor.values is None:
717 # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
718 ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
719 ifm_tensor.storage_shape = ifm_tensor.shape
720 elif ifm2_tensor.shape == [] and ifm2_tensor.values is None:
721 # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
722 ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
723 ifm2_tensor.storage_shape = ifm2_tensor.shape
724 return op
725
726
727def convert_softmax(op, arch, nng):
728 if op.type == Op.Softmax and op.run_on_npu:
729 softmax = SoftMax(op)
730 op = softmax.get_graph()
731 return op
732
733
734def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
735 r"""Whenever there is a subgraph with this topology:
736
737 Input X For X = -1 or X > 0
738 | \ / This subgraph can be replaced with either
739 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
740 | /
741 Max
742 """
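# Numeric illustration (assumed alpha): with X = 0.1, Max(x, 0.1 * x) equals x for x > 0 and
# 0.1 * x for x < 0, i.e. LeakyReLU with alpha 0.1; with X = -1, Max(x, -x) = |x|, i.e. Abs.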
743
744 if op.type == Op.Maximum:
745 # finds the Mul input(s) to the Max
746 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
747 if len(muls) == 1:
748 mul = muls[0].ops[0]
749 elif len(muls) == 2:
750 # In the case both inputs are Muls, find the one with the same input as the Max
751 mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
752 else:
753 # No Mul inputs
754 return op
755
756 # make sure the Mul doesn't have any other consumers
757 mul_ofm = mul.outputs[0]
758 if len(mul_ofm.consumers()) != 1:
759 return op
760 # make sure the Mul doesn't have a fused activation function
761 if mul.activation:
762 return op
763 ifm, ofm = op.get_ifm_ofm()
764 if ifm is None or ofm is None:
765 return op
766
767 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
768 return op
769 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
770 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
771 return op
772
773 # finds the branched input that goes to both the Max and the Mul
774 shared = set(op.inputs) & set(mul.inputs)
775 if len(shared) == 1:
776 shared_in = shared.pop()
777 # find the constant scalar input to the Mul
778 const_tens = (set(mul.inputs) - {shared_in}).pop()
779 # check that it is a scalar
780 if const_tens.shape != []:
781 return op
782 const = const_tens.ops[0]
783 # check that it is a constant
784 if const.type != Op.Const:
785 return op
786 # Remove the Mul from the shared input's consumers
787 shared_in.consumer_list.remove(mul)
788 else:
789 return op
790
791 val = const.outputs[0].values
792 if val >= 0:
793 new_op = Op.LeakyRelu
794 op.attrs["alpha"] = val
795 # to produce bit exact results, the alpha is not enough;
796 # save additional scaling info in attr "alpha_scale", to be used as input
797 # to the LUT construction
798 alpha_scalar = const_tens.values - const_tens.quantization.zero_point
799 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
800 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
801 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
802 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
803 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
804 elif val == -1:
805 new_op = Op.Abs
806 else:
807 return op
808
809 op.type = new_op
810 op.name = op.name.replace("Maximum", new_op.name)
811 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
812 op.inputs = [shared_in]
813 op.set_ifm_ofm_shapes()
814
815 # Record optimisation in debug database
816 DebugDatabase.add_optimised(op, op)
817
818 return op
819
820
821def convert_hardswish_to_lut(op, arch, nng):
822 if op.type == Op.HardSwish:
823 ifm, ofm = op.get_ifm_ofm()
824 # Generate the LUT
825 ifm_scale = np.double(ifm.quantization.scale_f32)
826 ofm_scale = np.double(ofm.quantization.scale_f32)
827 zp_in = ifm.quantization.zero_point
828 zp_out = ofm.quantization.zero_point
829 ifm_scale_hires = (1 / 128) * ifm_scale
830 relu_multiplier = np.double(3 / 32768)
831 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
832 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
833 # Use 16bit scale
834 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
835 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
836
837 values = []
838 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
839 quantized_min = min(ix)
840 quantized_max = max(ix)
841 for x in ix:
842 input_value = x - zp_in
843 input_value_hires = input_value * 128
844 # Compute the input value on essentially the output scale, not shifted yet
845 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
846 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
847 relu_value = np.int16(input_value_hires)
848 if relu_shift < 31:
849 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
850
851 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
852
853 if relu_shift < 31:
854 relu_value = fp_math.shift_left16(relu_value, 1)
855
856 if relu_shift > 31:
857 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
858
859 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
860 # Now convert that to a 16bit fixedpoint value in [0, 1]
861 relu_value = (relu_value + (1 << 15)) >> 1
862 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
863 shift = 31 - out_shift
864 shift = -shift if shift < 0 else 0
865 # Finally apply the output shift
866 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
867 lut_result = min(quantized_max, max(quantized_min, lut_result))
868 values.append(lut_result)
869 return convert_to_lut(op, values, "hardswish")
870 return op
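# Note: the reference HardSwish is x * relu6(x + 3) / 6; the loop above evaluates it in 16-bit fixed
# point for every possible quantized input (256 values) so the op can run as a single table lookup.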
871
872
873def convert_lrelu_to_mul_max(op, arch):
874 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
875 # (the opposite of convert_mul_max_to_abs_or_lrelu)
876 ifm, ofm = op.get_ifm_ofm()
877 if ifm is None or ofm is None:
878 return op
879
880 # Add multiplication with alpha
881 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
882 mul_alpha.add_input_tensor(ifm)
883 # Create const tensor containing alpha as scalar
884 alpha = op.attrs["alpha"]
885 quantization = ifm.quantization.clone()
886 quantization.min = 0
887 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
888 quantization.zero_point = 0
889 if np.isinf(1 / np.float32(alpha)):
890 # Handling of alpha near zero
891 quantization.scale_f32 = 1
892 scalar = 0
893 else:
894 quantization.scale_f32 = alpha
895 scalar = alpha
896 alpha_tens = create_const_tensor(
897 op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
898 )
899 alpha_tens.values = np.array([1])
900 mul_alpha.add_input_tensor(alpha_tens)
901 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
902 mul_alpha.set_output_tensor(fm_alpha)
903 mul_alpha.set_ifm_ofm_shapes()
904 DebugDatabase.add_optimised(op, mul_alpha)
905
906 if check_quantized_tens_scaling_equal(ifm, ofm):
907 # No identity multiplication is needed
908 fm_id = ifm
909 else:
910 # Add multiplication with identity
911 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
912 mul_identity.add_input_tensor(ifm)
913 # Create const tensor containing identity as scalar
914 quantization = ifm.quantization.clone()
915 quantization.min = 0
916 quantization.max = quantization.quant_max - quantization.quant_min
917 quantization.scale_f32 = 1
918 quantization.zero_point = 0
919 identity_tens = create_const_tensor(
920 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
921 )
922 mul_identity.add_input_tensor(identity_tens)
923 # Make sure that fm_id is allocated to a different address than fm_alpha
924 fm_id = ofm.clone(op.name + "_id", set_unique=True)
925 mul_identity.set_output_tensor(fm_id)
926 mul_identity.set_ifm_ofm_shapes()
927 DebugDatabase.add_optimised(op, mul_identity)
928
929 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
930 op.type = Op.Maximum
931 op.name = op.name.replace("LeakyRelu", "Maximum")
932 op.inputs = []
933 ifm.consumer_list.remove(op)
934 op.add_input_tensor(fm_alpha)
935 op.add_input_tensor(fm_id)
936 op.set_ifm_ofm_shapes()
937
938 DebugDatabase.add_optimised(op, op)
939 return op
940
941
942def convert_to_lut(op, lut_values, lut_name):
943 # Rewrite the operation by Add with scalar 0 + LUT activation
944 ifm = op.inputs[0]
945 if ifm is None:
946 return op
947 assert ifm.dtype.size_in_bytes() == 1
948 op.type = Op.Add
949 op.name = op.name + "_lut_" + lut_name
950 # Mark as no-op to enable potential fusing optimizations
951 op.attrs["is_nop"] = True
952 # Create an input tensor containing scalar zero
953 quantization = QuantizationParameters(0.0, 255.0)
954 quantization.scale_f32 = ifm.quantization.scale_f32
955 quantization.zero_point = 0
956 tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
957 op.add_input_tensor(tens)
958 op.ifm_shapes.append(Shape4D(tens.shape))
959
960 # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
961 # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
962 # should be the same as the IFM
963 op.forced_output_quantization = ifm.quantization
964 lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
965 op.set_activation_lut(lut_tensor)
966 op.set_ifm_ofm_shapes()
967 return op
968
969
970def convert_to_lut8(op, fn, fn_name):
971 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
972 # fn is a function(real) -> real
973 ifm, ofm = op.get_ifm_ofm()
974 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
975 return op
976 # Generate the LUT
977 ifm_scale = np.double(ifm.quantization.scale_f32)
978 ofm_scale = np.double(ofm.quantization.scale_f32)
979 zp_in = ifm.quantization.zero_point
980 zp_out = ofm.quantization.zero_point
981 values = []
982 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
983 quantized_min = min(ix)
984 quantized_max = max(ix)
985 for x in ix:
986 x_real = ifm_scale * (x - zp_in)
987 y_real = fn(x_real)
988 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
989 lut_result = min(quantized_max, max(quantized_min, lut_result))
990 values.append(lut_result)
991 return convert_to_lut(op, values, fn_name)
992
993
994def convert_lrelu_to_lut(op, arch):
995 ifm, ofm = op.get_ifm_ofm()
996 # Generate the LUT
997 alpha = op.attrs["alpha"]
998 ifm_scale = np.double(ifm.quantization.scale_f32)
999 ofm_scale = np.double(ofm.quantization.scale_f32)
1000 zp_in = ifm.quantization.zero_point
1001 zp_out = ofm.quantization.zero_point
1002 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1003 alpha_scalar = 1
1004 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1005 if "alpha_scaling" in op.attrs:
1006 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1007 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1008 values = []
1009 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1010 quantized_min = min(ix)
1011 quantized_max = max(ix)
1012 for x in ix:
1013 if x < zp_in:
1014 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1015 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1016 )
1017 else:
1018 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1019 lut_result = min(quantized_max, max(quantized_min, lut_result))
1020 values.append(lut_result)
1021 return convert_to_lut(op, values, "lrelu")
1022
1023
1024def convert_lrelu(op, arch, nng):
1025 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1026 if op.type != Op.LeakyRelu:
1027 return op
1028 ifm, ofm = op.get_ifm_ofm()
1029 if ifm is None or ofm is None:
1030 return op
1031 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1032 # use LUT for int8/uint8
1033 return convert_lrelu_to_lut(op, arch)
1034 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
1035 # use LeakyRelu unmodified for int16 with equal input/output scaling
1036 return op
1037 return convert_lrelu_to_mul_max(op, arch)
1038
1039
1040def convert_tanh_sigmoid_to_lut(op, arch, nng):
1041 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1042 if op.type == Op.Sigmoid:
1043 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1044 elif op.type == Op.Tanh:
1045 return convert_to_lut8(op, math.tanh, "tanh")
1046 return op
1047
1048
1049def remove_reshape_and_squeeze_ops(op, arch):
1050 if op.run_on_npu and op.type in (Op.Reshape, Op.Squeeze):
1051 ofm = op.ofm
1052 ifm = op.ifm
1053
1054 # Check if quantization is the same in the input and output for the reshape ops
1055 if not check_quantized_tens_scaling_equal(ifm, ofm):
1056 # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
1057 # In order to remove this reshape, either quantization properties need to be moved to the Operator,
1058 # or the reshape needs to be replaced with a NOP.
1059 return
1060
1061 # Check if ifm/ofm are network ifm/ofm
1062 ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
1063 ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
1064 ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
1065 # Check if the ifm is produced by the CPU or the ofm is consumed by the CPU
1066 ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
1067 ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
1068
1069 # This case should be handled prior to this function
1070 assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
1071
1072 if ofm_is_sg_ofm or ofm_is_cpu_consumed:
1073 # Bypassed by replacing ifm with ofm
1074 ofm.ops = []
1075 for prev_op in ifm.ops:
1076 prev_op.outputs = [ofm]
1077 ofm.ops.append(prev_op)
1078
1079 # All ifm consumers need to use ofm as input
1080 for ifm_cons in ifm.consumer_list:
1081 for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
1082 if cons_ifm == ifm:
1083 ifm_cons.set_input_tensor(ofm, ifm_idx)
1084 else:
1085 # Bypassed by replacing ofm with ifm
1086 for cons in ofm.consumer_list:
1087 for ifm_idx, cons_ifm in enumerate(cons.inputs):
1088 if cons_ifm == ofm:
1089 cons.set_input_tensor(ifm, ifm_idx)
1090
1091
1092def fuse_activation_function_with_prev(op, arch, nng):
1093 # if op is a no-op: attempts to move the activation function to the preceding op
1094 if not op.attrs.get("is_nop", False) or op.activation is None:
1095 return op
1096 ifm, ofm = op.get_ifm_ofm()
1097 if ifm is None or ofm is None:
1098 return op
1099 # finds the input(s) to the operation
1100 prev_op = ifm.ops[0]
1101 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1102 fuse = (
1103 prev_op.run_on_npu
1104 and prev_op.type.npu_block_type != NpuBlockType.Default
1105 and len(ifm.ops) == 1
1106 and len(prev_op.outputs[0].consumers()) == 1
1107 and prev_op.activation is None
1108 )
1109 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1110 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1111 # LUT currently only works correctly for elementwise ops
1112 fuse = False
1113 if not fuse:
1114 return op
1115 # Move the fused activation function + corresponding info to prev_op
1116 prev_op.activation = op.activation
1117 prev_op.forced_output_quantization = op.forced_output_quantization
1118 if op.activation_lut is not None:
1119 prev_op.set_activation_lut(op.activation_lut)
1120 # Bypass op
1121 prev_op.set_output_tensor(ofm)
1122 DebugDatabase.add_optimised(op, prev_op)
1123 return op
1124
1125
1126def _leading_pad_ok(leading_pad, stride, kernel_size):
1127 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1128 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
1129 max_size = kernel_size // 2
1130 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
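# Illustrative example (assumed kernel/stride): kernel size 7 and stride 2 give max_size 3, so a
# left/top pad of 3 or any multiple of 2 is acceptable, but a pad of 1 is not; for kernel sizes up to
# 2 * stride + 1, max_size <= stride and any leading pad passes.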
1131
1132
1133def replace_pad_by_hw_pad(op: Operation, arch, nng):
1134 """
1135 Tries to completely remove a PAD operator by using hardware padding.
1136 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1137 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1138 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1139 if both operations can be run on the NPU.
1140 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1141 """
1142 if (
1143 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
1144 and op.run_on_npu
1145 and op.attrs["padding"] == Padding.VALID
1146 ):
1147 pad_op = op.ifm.ops[0]
1148 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1149 return op
1150 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1151 return op
1152 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1153 k = op.kernel
1154 k_w, k_h = k.dilated_wh()
1155
1156 # Check if the PAD operator can be replaced by hardware padding
1157 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1158 # Too much padding, it would require hardware padding to actually insert zeros
1159 return op
1160 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1161 return op
1162
1163 if op.type.is_avgpool_op():
1164 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1165 for pad, k_size in (
1166 (left, k_w),
1167 (right, k_w),
1168 (top, k_h),
1169 (bottom, k_h),
1170 ):
1171 if pad not in (0, k_size // 2):
1172 return op
1173 # Average pool is converted to depthwise, because NPU average pool + same padding
1174 # has a special implementation that is different from PAD followed by average pool with
1175 # valid padding.
1176 k_w, k_h = op.kernel.width, op.kernel.height
1177 ifm = op.ifm
1178 # Remember other inputs
1179 other_inputs = op.inputs[1:]
1180 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1181 quantization = QuantizationParameters(0.0, 255.0)
1182 quantization.scale_f32 = 1.0 / (k_w * k_h)
1183 quantization.zero_point = 0
1184 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1185 weights = np.full(shape, 1)
1186
1187 weight_tens = create_const_tensor(
1188 op.name + "_weights",
1189 shape,
1190 op.ifm.dtype,
1191 weights,
1192 np.uint8,
1193 purpose=TensorPurpose.Weights,
1194 quantization=quantization,
1195 )
1196 weight_tens.values = weights
1197 op.type = Op.DepthwiseConv2DBias
1198 op.inputs = []
1199 op.add_input_tensor(ifm)
1200 op.add_input_tensor(weight_tens)
1201 # Add bias tensor, all biases set to 0
1202 op.inputs.append(None)
1203 fixup_bias_tensors(op, arch, nng)
1204 # Add other inputs
1205 op.inputs.extend(other_inputs)
1206 op.rounding_mode = NpuRoundingMode.NATURAL
1207
1208 # Bypass the PAD operator
1209 op.set_input_tensor(pad_op.ifm, 0)
1210 # Adjust the padding attributes of the convolution operator
1211 op.attrs["padding"] = Padding.EXPLICIT
1212 op.attrs["explicit_padding"] = (top, left, bottom, right)
1213 op.set_ifm_ofm_shapes()
1214 return op
1215
1216
1217def convert_pad(op: Operation, arch, nng):
1218 """
1219 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1220 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1221 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1222 """
1223 if op.type != Op.Pad or not op.run_on_npu:
1224 return op
1225 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1226
1227 ifm = op.ifm
1228 assert ifm is not None
1229 ifm_shape = Shape4D(ifm.shape)
1230 ofm = op.ofm
1231 assert ofm is not None
1232 ofm.ops = []
1233 ofm_shape = op.ofm_shapes[0]
1234
1235 # Average pool op that copies IFM to the right place inside the OFM
1236 shp0 = Shape4D(0, 0, 0, 0)
1237 shp_top = shp0.with_height(top)
1238 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1239 avgpool_op.activation = op.activation
1240 quant = ofm.quantization
1241 pad_value = quant.zero_point
1242 # Add operations that fill the borders of the OFM
1243 if top > 0:
1244 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1245 zero_tens = create_const_tensor(
1246 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1247 )
1248 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1249 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1250 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1251 if bottom > 0:
1252 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1253 zero_tens = create_const_tensor(
1254 op.name + "_bottom",
1255 shape.as_list(),
1256 ofm.dtype,
1257 shape.elements() * [pad_value],
1258 np.uint8,
1259 quantization=quant,
1260 )
1261 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1262 create_avg_pool_for_concat(
1263 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1264 )
1265 if left > 0:
1266 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1267 zero_tens = create_const_tensor(
1268 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1269 )
1270 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1271 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1272 if right > 0:
1273 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1274 zero_tens = create_const_tensor(
1275 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1276 )
1277 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1278 create_avg_pool_for_concat(
1279 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1280 )
1281
1282 op.type = Op.ConcatTFLite
1283 return avgpool_op
1284
1285
1286def add_attrs_to_resizebilinear(op, arch, nng):
1287 if op.type == Op.ResizeBilinear and op.run_on_npu:
1288 input_tensor = op.inputs[0]
1289 input_shape = op.ifm_shapes[0]
1290 upscaled_height = input_shape.height * 2
1291 upscaled_width = input_shape.width * 2
1292 out_shape = op.ofm_shapes[0]
1293 if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
1294 # this means the output is supposed to be a x2 upscale,
1295 # so we need to do SAME padding
1296 op.attrs["padding"] = Padding.SAME
1297 elif (
1298 op.attrs["align_corners"]
1299 and out_shape.height == (upscaled_height - 1)
1300 and out_shape.width == (upscaled_width - 1)
1301 ):
1302 # here we can just run the avg pool without padding and
1303 # produce a (M * 2 - 1, N * 2 - 1) sized output
1304 op.attrs["padding"] = Padding.VALID
1305 else:
1306 return op
1307 input_tensor.resampling_mode = resampling_mode.NEAREST
1308 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
1309 return op
1310
1311
1312def fixup_bias_tensors(op, arch, nng):
1313 if op.type.needs_bias() and op.bias is None:
1314 # Op has no bias, add bias tensor filled with zeros
1315 nr_biases = op.inputs[1].shape[-1]
1316 bias_values = [0] * nr_biases
1317 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
1318 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1319
1320 return op
1321
1322
1323def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1324 if op.type == Op.Mean and op.run_on_npu:
1325 keep_dims = op.attrs.get("keep_dims", False)
1326 inp, axis = op.inputs
1327 shape = inp.shape
1328 dims = len(shape)
1329
1330 # Height and width axes have different index depending on dimensions
1331 if axis.shape == [] or axis.shape[0] == 1: # single axis
1332 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1333 if dims in (2, 3):
1334 if axis == 0:
1335 h, w = shape[axis], 1
1336 else:
1337 h, w = 1, shape[axis]
1338 else:
1339 if axis == 1:
1340 h, w = shape[axis], 1
1341 else:
1342 h, w = 1, shape[axis]
1343 else: # multiple axes
1344 axis = sorted(axis.values)
1345 h, w = [shape[i] for i in axis]
1346
1347 # Set necessary depthwise attributes
1348 op.attrs.update(
1349 {
1350 "padding": Padding.VALID,
1351 "stride_h": 1,
1352 "stride_w": 1,
1353 "strides": (1, 1, 1, 1),
1354 "depth_multiplier": 1,
1355 "channel_multiplier": 1,
1356 "dilation_h_factor": 1,
1357 "dilation_w_factor": 1,
1358 "dilation": (1, 1, 1, 1),
1359 }
1360 )
1361 # Change op type
1362 op.type = Op.DepthwiseConv2DBias
1363 # Set IFM/OFM shapes after changing op type
1364 op.set_ifm_ofm_shapes()
1365
1366 weight_scale, bias = 1, None
1367 ofmq, ifmq = op.ofm.quantization, inp.quantization
1368 # Set rounding mode, scaling and zero point based on which reference implementation to match
1369 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1370 if inp.dtype == DataType.uint8:
1371 # This attribute means a different scaling calculation is used in order to match reference
1372 op.low_precision_scaling = True
1373 weight_scale = h * w
1374 # Set zero points to 0 as they will be adjusted for with bias term
1375 foq = ofmq.clone()
1376 foq.zero_point = 0
1377 fiq = ifmq.clone()
1378 fiq.zero_point = 0
1379 op.forced_input_quantization = fiq
1380 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1381 # If the bias term is outside uint8 range, we need an Add op to apply it.
1382 if bias_term < 0 or bias_term > 255:
1383 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1384 # Bias term has higher bitness (i32) than input/output (u8).
1385 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1386 # the bias can only effectively assume values in the range [-255, 255].
1387 intermediate.dtype = DataType.int16
1388 intermediate.quantization.zero_point = 0
1389 add_op = Operation(Op.Add, op.name + "_bias")
1390 add_op.forced_output_quantization = foq
1391 add_op.add_input_tensor(intermediate)
1392 quant = QuantizationParameters()
1393 quant.zero_point = 0
1394 bias_term_tens = create_const_tensor(
1395 op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
1396 )
1397 add_op.add_input_tensor(bias_term_tens)
1398 add_op.set_output_tensor(op.ofm)
1399 add_op.set_ifm_ofm_shapes()
1400 add_op.activation = op.activation
1401 op.activation = None
1402 op.set_output_tensor(intermediate)
1403 op.set_ifm_ofm_shapes()
1404 # If not, we can just do it with the OFM zero point.
1405 else:
1406 foq.zero_point = bias_term
1407 op.forced_output_quantization = foq
1408 else:
1409 assert inp.dtype == DataType.int8
1410 # Use a depthwise to calculate the sum,
1411 # followed by a multiplication with 1/N to get the MEAN
1412 weight_scale = 1
1413 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1414 intermediate.dtype = DataType.int16
1415 mul_op = Operation(Op.Mul, op.name + "_mul")
1416 mul_op.add_input_tensor(intermediate)
1417 # Create scalar containing 1/N
1418 quant = QuantizationParameters()
1419 quant.zero_point = 0
1420 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1421 # while rounding mode NATURAL would round this to -1.
1422 # This can only occur if N is even, and can be emulated by
1423 # multiplying with a number that is slightly smaller than 1/N.
1424 # It must be so small that other roundings are not affected;
1425 # the calculated value is based on worst case,
1426 # which is sum 256 * N (the maximum sum that can occur with int8)
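# Worked example (illustrative): for h * w = 4, eps = 1 / (256 * 5) = 1 / 1280 and
# scale_f32 = 1 / (4 - 1/1280) ~= 0.25005; for odd N, eps is 0 and the scale is exactly 1 / N.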
1427 n = int(h * w)
1428 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1429 quant.scale_f32 = 1 / (n - eps)
1430 scalar = create_const_tensor(
1431 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1432 )
1433 mul_op.add_input_tensor(scalar)
1434 mul_op.set_output_tensor(op.ofm)
1435 mul_op.set_ifm_ofm_shapes()
1436 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1437 mul_op.activation = op.activation
1438 op.activation = None
1439 op.set_output_tensor(intermediate)
1440 op.set_ifm_ofm_shapes()
1441 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1442 # Here we can just use a simple AvgPool with truncating rounding,
1443 # as we're emulating simple integer division.
1444 op.rounding_mode = NpuRoundingMode.TRUNCATE
1445 op.type = Op.AvgPool
1446 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1447 else:
1448 op.rounding_mode = NpuRoundingMode.NATURAL
1449 weight_scale = 1 / (h * w)
1450 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1451 bias = -ifmq.zero_point * h * w
1452 fiq = ifmq.clone()
1453 fiq.zero_point = 0
1454 op.forced_input_quantization = fiq
1455
1456 # Change dimensions to 4
1457 if dims < 4:
1458 shape = [1] + shape
1459 if dims == 2:
1460 shape += [1]
1461
1462 # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
1463 if h > 64:
1464 shape = [shape[0], 1, h * w, shape[3]]
1465 op.ifm_shapes[0] = Shape4D(shape)
1466 if h > 256 and op.type == Op.AvgPool:
1467 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1468
1469 # If the AvgPool version is used, we don't need to do anything else
1470 if op.type == Op.AvgPool:
1471 return op
1472
1473 # Make unit weight tensor quantization
1474 weight_quant = ifmq.clone()
1475 weight_quant.min = 0
1476 weight_quant.max = 255
1477 weight_quant.scale_f32 = weight_scale
1478 weight_quant.zero_point = 0
1479
1480 # Set weight shape to [H,W,C,B]
1481 weight_shape = shape[1:4] + [shape[0]]
1482 # Add unit weight tensor
1483 op.set_input_tensor(
1484 create_const_tensor(
1485 "weights",
1486 weight_shape,
1487 inp.dtype,
1488 np.ones(weight_shape),
1489 value_dtype=np.uint8,
1490 quantization=weight_quant,
1491 ),
1492 1,
1493 )
1494 op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
1495
1496 # Add None bias tensor
1497 op.inputs.append(None)
1498 # Add bias tensor
1499 if bias:
1500 bias_shape = [shape[-1]]
1501 op.set_input_tensor(
1502 create_const_tensor(
1503 "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
1504 ),
1505 2,
1506 )
1507
1508 return op
1509
1510
1511def supported_operator_check(op, arch, nng):
1512 op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
1513 return op
1514
1515
1516def tflite_optimise_graph(nng, arch):
1517 # Pre-processing step
1518 pre_process_list = [
1519 supported_operator_check,
1520 set_ifm_ofm_op_shapes,
1521 ]
1522
1523 for idx, sg in enumerate(nng.subgraphs):
1524 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1525 nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
1526 )
1527
1528 # Handle Concat Ops
1529 for idx, sg in enumerate(nng.subgraphs):
1530 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1531 sg.refresh_after_modification()
1532
1533 # Handle Split Ops
1534 for idx, sg in enumerate(nng.subgraphs):
1535 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1536 nng,
1537 sg,
1538 arch,
1539 [],
1540 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1541 rewrite_unsupported=False,
1542 )
1543
1544 for idx, sg in enumerate(nng.subgraphs):
1545 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1546 nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
1547 )
1548
1549 # Handle sg input output
1550 for idx, sg in enumerate(nng.subgraphs):
1551 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1552 nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
1553 )
1554
1555 # Removal of Reshape and Squeeze ops
1556 for sg in nng.subgraphs:
1557 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshape_and_squeeze_ops])
1558 sg.refresh_after_modification()
1559
1560 # Rewrite of operators
1561 op_rewrite_list = [
1562 set_tensor_equivalence,
1563 convert_mean_to_depthwise_conv_or_avgpool,
1564 convert_depthwise_to_conv,
1565 convert_conv_to_fc,
1566 convert_softmax,
1567 optimise_strided_conv,
1568 convert_hardswish_to_lut,
1569 rewrite_fully_connected_input,
1570 convert_batched_fc_shape,
1571 fixup_conv2d_backprop,
1572 fixup_relus_with_differing_ifm_ofm_scaling,
1573 fixup_elementwise_with_scalars,
1574 reorder_depthwise_weights,
1575 fixup_resizebilinear,
1576 fixup_bias_tensors,
1577 convert_mul_max_to_abs_or_lrelu,
1578 convert_lrelu,
1579 convert_tanh_sigmoid_to_lut,
1580 replace_pad_by_hw_pad,
1581 ]
1582
1583 for idx, sg in enumerate(nng.subgraphs):
1584 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1585 nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
1586 )
1587
1588 for idx, sg in enumerate(nng.subgraphs):
1589 # remove passthrough tensors and attempt further optimizations
1590 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1591 nng,
1592 sg,
1593 arch,
1594 [remove_passthrough_tensor],
1595 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1596 )
1597
1598 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
1599 # since ifm/ofm_shapes are of importance to this function
1600 for sg in nng.subgraphs:
1601 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1602 sg.refresh_after_modification()
1603
1604 return nng