# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid
from typing import Tuple

import numpy as np

from . import fp_math
from . import lut
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


60def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
61 """Creates an average pool for the given concat op/input feature map"""
62 ofm = concat_op.ofm
63 avgpool_op = create_avgpool_nop(name)
64 avgpool_op.inputs = [ifm]
65 avgpool_op.outputs = [ofm]
66
67 avgpool_op.write_offset = write_offset
68 avgpool_op.write_shape = ifm_shape
69 ofm.ops.append(avgpool_op)
70 DebugDatabase.add_optimised(concat_op, avgpool_op)
71 avgpool_op.ifm_shapes.append(ifm_shape)
72 avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
73 avgpool_op.memory_function = Op.ConcatSliceWrite
74 return avgpool_op
75
76
77def remove_passthrough_tensor(tens, arch, nng):
78 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
79 assert len(tens.ops[0].inputs) == 1
80 tens = tens.ops[0].inputs[0]
81 return tens
82
83
84def rewrite_concat_ops(op, arch):
85 if not op.run_on_npu or not op.type.is_concat_op():
86 return
87
88 axis_4D = 0
89 ofm = op.ofm
90 ofm.ops = []
91 offset = 0
92
93 unfuse_activation_function(op)
94
95 if op.type == Op.Pack:
96 # Pack is also referred to as Stack
97 axis = int(op.attrs["axis"])
98 if axis < 0: # Convert to positive axis
99 axis = len(op.inputs[0].shape) + 1 + axis
100
101 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
102
103 axis_4D = axis + (4 - len(desired_shape))
104
105 for idx, inp in enumerate(op.inputs):
106 op.ifm_shapes[idx] = Shape4D(desired_shape)
107 op.type = Op.PackReshaped
108
109 inputs, axis = op.get_concat_inputs_axis()
110 for idx, inp in enumerate(inputs):
111 if op.type != Op.PackReshaped:
112 op.ifm_shapes[idx] = Shape4D(inp.shape)
113 if axis >= 0:
114 axis_4D = axis + (4 - len(inp.shape))
115 else:
116 axis_4D = axis
117 write_offset = [0, 0, 0, 0]
118 write_offset[axis_4D] = offset
119 concat_end = offset + op.ifm_shapes[idx][axis_4D]
120 create_avg_pool_for_concat(
121 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
122 )
123 offset = concat_end
124 assert ofm.shape[axis] == offset
125
126 return op
127
128
129def rewrite_split_ops(tens, arch, nng):
130
131 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
132 split_op = tens.ops[0]
133
134 # Not supported so leave it and run on CPU
135 if not split_op.run_on_npu:
136 return tens
137
138 inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
139
140 tens.ops = []
141 new_op = Operation(Op.SplitSliceRead, split_op.name)
142 new_op.inputs = [inp]
143 ofm_shape_idx = 0
144 read_shape = offset_end
145
146 # For Split the offset cannot be extracted from the tensor so it has to
147 # be calculated from the index of the output tensor
148 if axis is not None:
149 # Get the start and end of the split
150 offset_start = [0] * 4
151 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
152 for idx, out in enumerate(outputs):
153 if axis_4D_list is not None:
154 axis_4D = axis_4D_list[idx]
155 else:
156 split_op.ofm_shapes[idx] = Shape4D(out.shape)
157 if axis >= 0:
158 axis_4D = axis + (4 - len(out.shape))
159 else:
160 axis_4D = axis
161
162 if out == tens:
163 ofm_shape_idx = idx
164 read_shape = split_op.ofm_shapes[idx]
165 break
166
167 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
168
169 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
170 new_op.read_shapes[0] = read_shape
171 new_op.run_on_npu = True
172 new_op.set_output_tensor(tens)
173 new_op.ifm_shapes.append(Shape4D(inp.shape))
174 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
175 DebugDatabase.add_optimised(split_op, new_op)
176
177 return tens
178
179
180def remove_SplitSliceRead(op, arch):
181
182 if op.type == Op.SplitSliceRead:
        # Check if the SplitSliceRead can be moved to the tensor consumer, or if an avgpool needs to be inserted
184 if (
185 len(op.ofm.consumer_list) == 1
186 and op.ofm.consumer_list[0] is not None
187 and op.ofm.consumer_list[0].run_on_npu
188 and op.ofm.consumer_list[0].type != Op.Reshape
189 and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
190 ):
191 # SplitSliceRead can be performed by tensor consumer
192 cons_op = op.ofm.consumer_list[0]
193 if cons_op.ifm == op.ofm:
194 cons_op.read_offsets[0] = op.read_offsets[0]
195 cons_op.read_shapes[0] = op.read_shapes[0]
196 cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
197 cons_op.ifm_shapes[0] = op.ifm_shapes[0]
198 elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
199 cons_op.read_offsets[1] = op.read_offsets[0]
200 cons_op.read_shapes[1] = op.read_shapes[0]
201 cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
202 cons_op.ifm_shapes[1] = op.ifm_shapes[0]
203
204 if "skirt" in cons_op.attrs:
205 assert cons_op.attrs["explicit_padding"] == cons_op.attrs["skirt"]
206 cons_op.attrs["skirt"] = None
207 cons_op.attrs["force_padding"] = True
208 op.ofm.consumer_list.remove(cons_op)
209 op.ofm.ops = []
210 op.ifm.consumer_list.remove(op)
211 else:
212 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
213 avgpool_op.add_input_tensor(op.ifm)
214 avgpool_op.outputs = [op.ofm]
215 op.ofm.ops.remove(op)
216 op.ofm.ops.append(avgpool_op)
217 avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
218 avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
219 avgpool_op.read_offsets[0] = op.read_offsets[0]
220 avgpool_op.read_shapes[0] = op.read_shapes[0]
221
222 op.ifm.consumer_list.remove(op)
223 DebugDatabase.add_optimised(op, avgpool_op)
224
225
226def insert_copy_op_after_tens(tens):
227 tens_cons_list_copy = tens.consumer_list.copy()
228
    # Create an avg_pool nop op with ifm as input
230 copy_tens = tens.clone()
231 copy_op = create_avgpool_nop(tens.name + "_avgpool")
232 copy_op.add_input_tensor(tens)
233 copy_op.set_output_tensor(copy_tens)
234 copy_op.set_ifm_ofm_shapes()
235 copy_op.run_on_npu = True
236
237 # Set copy_ifm consumers
238 for tens_cons in tens_cons_list_copy:
239 if tens_cons is not None:
240 for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
241 if cons_inp == tens:
242 tens_cons.set_input_tensor(copy_tens, ifm_idx)
243
244 DebugDatabase.add_optimised(tens.ops[0], copy_op)
245
246
247def fix_sg_input_output(op, arch, nng):
248 if not op.run_on_npu or op.type != Op.Reshape:
249 return op
250
    # For the Reshape operators we want to remove, tensors are removed.
    # But in order to do this, they cannot be outputs of the subgraph;
    # this needs to be fixed prior to the removal.
    # The solution is to add an avgpool NOP to maintain the original tensor.
    # This also applies when the reshape ifm/ofm is produced or consumed by the CPU.
257
258 # Check if operator ifm/ofm are sg ifm/ofm
259 ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
260 ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
261 ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
    # Check if the ifm is produced, respectively the ofm consumed, by the CPU
263 ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
264 ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
265
266 if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
267 # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape
268 insert_copy_op_after_tens(op.ifm)
269
270 return op
271
272
273def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
274 """
275 Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
276 that provides equivalent results.
277 """
278 total_padding = needed_total_padding(input_size, stride, filter_size)
279 # The top/left padding can be taken as is from the PAD
280 output_pad_before = pad_before
281 # The bottom/right padding might need downward adjustment depending on stride/input size
282 output_pad_after = pad_after
283 while output_pad_after > 0 and output_pad_after % stride != (total_padding - pad_before) % stride:
284 output_pad_after -= 1
285 return output_pad_before, output_pad_after
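
# Worked example (illustrative): padding an 8-wide dimension by 1 on each side
# before a 3x3 kernel:
#   calc_explicit_padding(8, 1, 3, 1, 1)  # -> (1, 1), all of the PAD is kept
#   calc_explicit_padding(8, 2, 3, 1, 1)  # -> (1, 0), the trailing pad would
#                                         #    never be read at stride 2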
286
287
288def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
289 k_w, k_h = kernel.dilated_wh()
290 s_x, s_y = kernel.stride
291 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
292 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
293 if padding_type == Padding.SAME:
294 left_pad = (xpad + 0) // 2
295 right_pad = (xpad + 1) // 2
296 top_pad = (ypad + 0) // 2
297 bottom_pad = (ypad + 1) // 2
298 elif padding_type == Padding.VALID:
299 left_pad = 0
300 right_pad = 0
301 top_pad = 0
302 bottom_pad = 0
303 elif padding_type == Padding.EXPLICIT:
304 # Padding is specified in a PAD operator which has been bypassed.
305 top, left, bottom, right = explicit_padding
306 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
307 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
308 else:
        raise UnsupportedFeatureError(f"Unknown padding: {padding_type}")
310 padding = (top_pad, left_pad, bottom_pad, right_pad)
311 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
312 return padding, skirt
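
# Worked example (illustrative): with SAME padding, a 3x3 kernel at stride 1 on
# an 8x8 IFM needs xpad = ypad = 2, which is split as padding
# (top, left, bottom, right) = (1, 1, 1, 1); at stride 2 only a single pad
# row/column is needed and it is placed at the bottom/right: (0, 0, 1, 1).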
313
314
315def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
316 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
317 if padding_type == Padding.SAME:
318 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
319 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
320 right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
321 bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
322 left_pad = max(kernel_width - 1 - right_pad, 0)
323 top_pad = max(kernel_height - 1 - bottom_pad, 0)
324 elif padding_type == Padding.VALID:
325 right_pad = max(kernel_width - 2, 0)
326 bottom_pad = max(kernel_height - 2, 0)
327 left_pad = kernel_width - 1
328 top_pad = kernel_height - 1
329 else:
        raise UnsupportedFeatureError(f"Unknown padding: {padding_type}")
331 padding = (top_pad, left_pad, bottom_pad, right_pad)
332 skirt = padding
333 return padding, skirt
334
335
336def fixup_conv2d_backprop(op, arch, nng):
337 if op.type == Op.Conv2DBackpropInput:
338 # flip the inputs
339 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
340 op.type = Op.Conv2DBackpropInputSwitchedBias
341 op.ifm.resampling_mode = resampling_mode.TRANSPOSE
342
343 # Update strides
344 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
345
346 return op
347
348
349# Convert the op to an elementwise add
350def convert_resizebilinear_1x1_to_add(op):
351 op.type = Op.Add
352 op.name = op.name + "_add"
353 op.attrs["resizebilinear"] = True
354 # Create an input tensor filled with zeros
355 shape = op.ofm_shapes[0].as_list()
356 tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
357 tens.values = np.zeros(shape)
358 tens.quant_values = np.zeros(shape, np.uint8)
359 tens.quantization = QuantizationParameters(0.0, 255.0)
360 tens.quantization.scale_f32 = 1.0
361 tens.quantization.zero_point = 0
362 tens.consumer_list = [op]
363 tens_op = op.inputs[1].ops[0]
364 tens_op.set_output_tensor(tens)
365 # Set the add inputs
366 op.inputs[1] = op.inputs[0]
367 op.inputs[0] = tens
368 op.set_ifm_ofm_shapes()
369
370 return op
371
372
373# Convert ResizeBilinear to a number of 2x2 pool ops
374def convert_resizebilinear_to_2x2_pool(op):
375 count = 0
376 pre_op = op
377 outputs = op.outputs
378
379 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
380 if op.attrs["align_corners"]:
381 shape_modifier = 1
382 op.attrs["padding"] = Padding.VALID
383 else:
384 shape_modifier = 0
385 op.attrs["padding"] = Padding.SAME
386 op.inputs[0].resampling_mode = resampling_mode.NEAREST
387
388 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
389 out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
390 if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
391 return op
392
393 while (upscaled_shape < out_shape).all():
394 if count == 0:
395 scaled_op = pre_op
396 else:
397 scaled_op = op.clone("_{}".format(count))
398 scaled_op.inputs[0] = pre_op.outputs[0]
399
400 upscaled_shape = upscaled_shape * 2 - shape_modifier
401
402 if (upscaled_shape == out_shape).all():
403 scaled_op.outputs = outputs
404 scaled_op.outputs[0].ops = [scaled_op]
405 else:
406 shape = op.ofm_shapes[0].as_list()
407 shape[1:3] = upscaled_shape
408 out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
409 out_tens.quantization = op.outputs[0].quantization.clone()
410 out_tens.quantization.quant_min = np.iinfo(np.int16).min
411 out_tens.quantization.quant_max = np.iinfo(np.int16).max
412 scaled_op.set_output_tensor(out_tens)
413 pre_op = scaled_op
414 count += 1
415
416 # Setup the scale value
417 if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
418 scaled_op.rescale = 128
419 elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
420 scaled_op.rescale = 1 / 128
421 else:
422 scaled_op.rescale = None
423 scaled_op.set_ifm_ofm_shapes()
424
425 return op
426
427
428def fixup_resizebilinear(op, arch, nng):
429 if op.type == Op.ResizeBilinear and op.run_on_npu:
430 if op.ifm_shapes[0] == op.ofm_shapes[0]:
431 # Bypass nop resizebilinear
432 op.inputs = op.inputs[:1]
433 op.type = Op.Identity
434 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
435 convert_resizebilinear_1x1_to_add(op)
436 else:
437 convert_resizebilinear_to_2x2_pool(op)
438
439 return op
440
441
442def convert_nop_split_to_identity(op, arch, nng):
443 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it does not, remove_passthrough_tensor will fail appropriately
446 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
447 op.type = Op.Identity
448 return op
449
450
451def rewrite_fully_connected_input(op, arch, nng):
452 if op.type == Op.FullyConnected:
453 n_in_elems = op.weights.shape[-2]
454 elms = op.ifm.elements()
455 batch_size = elms // n_in_elems
456 assert batch_size * n_in_elems == elms
457
458 op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
459 return op
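
# Worked example (illustrative): a FullyConnected op whose weight tensor has
# shape[-2] == 1024 and whose IFM holds 2048 elements gets ifm_shapes[0]
# rewritten to Shape4D([2, 1, 1, 1024]), i.e. a batch of two 1024-element rows.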
460
461
462def convert_batched_fc_shape(op, arch, nng):
463 if op.type == Op.FullyConnected:
464 # Check if the first dimension indicates batching
465 if op.ifm_shapes[0].batch > 1:
466 batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
467 n = op.ifm_shapes[0].batch
468 h, w = batching_split.get(n, (1, n))
469 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
470
471 # Reshape Weights to be 4D. IO becomes HWIO
472 weight_tensor = op.inputs[1]
473 weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
474 weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
475
476 n = op.ofm_shapes[0].batch
477 h, w = batching_split.get(n, (1, n))
478 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
479 return op
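
# Worked example (illustrative): a batched FullyConnected with ifm_shapes[0] =
# Shape4D([8, 1, 1, 1024]) is remapped via batching_split to
# Shape4D([1, 2, 4, 1024]); a batch size not in the table, e.g. 3, falls back
# to Shape4D([1, 1, 3, 1024]). The 2-D weights gain two leading singleton axes
# so their layout matches the new 4-D shapes.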
480
481
482def unfuse_activation_function(op):
483 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
484 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
485 op.activation = None
486 out_tens = op.outputs[0]
487 intermediate_tens = out_tens.clone("_act_intermediate")
488 act_op.set_output_tensor(out_tens)
489 act_op.add_input_tensor(intermediate_tens)
490 op.set_output_tensor(intermediate_tens)
491 act_op.set_ifm_ofm_shapes()
492
493
494def rewrite_stridedslice_output(op, arch, nng):
495 if not op.run_on_npu or op.type != Op.StridedSlice:
496 return op
497
498 new_axis_mask = op.attrs["new_axis_mask"]
499 shrink_axis_mask = op.attrs["shrink_axis_mask"]
500
501 if shrink_axis_mask == 0 and new_axis_mask == 0:
502 return op
503
504 axis_4D = [0] * len(op.outputs)
505 for idx, out_tens in enumerate(op.outputs):
506 output_shape = list(out_tens.shape)
507
508 if shrink_axis_mask != 0:
509 n = 0
510 axis = 0
511 while shrink_axis_mask:
512 prev_mask = shrink_axis_mask
513 n += 1
514 shrink_axis_mask &= shrink_axis_mask - 1
515 axis = int(math.log2(prev_mask - shrink_axis_mask))
516 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
517
518 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
519 op.attrs["shrink_axis_mask"] = 0
520 if axis >= 0:
521 axis_4D[idx] = axis + (4 - len(output_shape))
522 else:
523 axis_4D[idx] = axis
524 op.ofm_shapes[idx] = Shape4D(output_shape)
525
526 elif new_axis_mask != 0:
527 n = 0
528 axis = 0
529 while new_axis_mask:
530 prev_mask = new_axis_mask
531 n += 1
532 new_axis_mask &= new_axis_mask - 1
533 axis = int(math.log2(prev_mask - new_axis_mask))
534 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
535 new_axis_mask >>= 1
536
537 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
538 op.attrs["new_axis_mask"] = 0
539 if axis >= 0:
540 axis_4D[idx] = axis + (4 - len(output_shape))
541 else:
542 axis_4D[idx] = axis
543 op.ofm_shapes[idx] = Shape4D(output_shape)
544
545 op.attrs["split_axis_4D"] = axis_4D
546 return op
547
548
549def rewrite_unpack_output(op, arch, nng):
550 tens = op.outputs[0]
551 if op.run_on_npu and op.type == Op.Unpack:
552 # Unpack is also referred to as Unstack
553 axis = int(op.attrs["axis"])
554 if axis < 0: # Convert to positive axis
555 axis = len(op.inputs[0].shape) + 1 + axis
556 op.type = Op.UnpackReshaped
557 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
558
559 axis_4D = axis + (4 - len(desired_output_shape))
560 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
561
562 for idx, out_tens in enumerate(op.outputs):
563 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
564 return op
565
566
567def add_padding_fields(op, arch, nng):
568 if op.run_on_npu:
569 if "padding" in op.attrs:
570 input_shape = op.ifm_shapes[0]
571 output_shape = op.ofm_shapes[0]
572 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
573 kernel_size = op.inputs[1].shape[:2]
574 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
575 kernel_size = op.attrs["ksize"][1:3]
576 else:
577 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
578
579 if op.type == Op.Conv2DBackpropInputSwitchedBias:
580 upscaling_factor = output_shape.height // input_shape.height
581 padding, skirt = calc_upscaled_padding_and_skirt(
582 op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
583 )
584 else:
585 padding, skirt = calc_padding_and_skirt(
586 op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
587 )
588
589 op.attrs["explicit_padding"] = padding
590 op.attrs["skirt"] = skirt
591
592 return op
593
594
595def convert_depthwise_to_conv(op, arch, nng):
    # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
    # the ofm depth equals the depth multiplier.
    # If those conditions are true, then we can perform a simple
    # switch of the operator type (and weight order)
600
601 if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1):
602 ifm_shape = op.ifm_shapes[0]
603 weight_tensor = op.inputs[1]
604 ofm_shape = op.ofm_shapes[0]
605 if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]):
606 # Change op type to Conv2d
607 op.type = Op.Conv2DBias
608 del op.attrs["channel_multiplier"]
609 del op.attrs["depth_multiplier"]
610
611 weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
612 weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
613 else:
614 raise UnsupportedFeatureError(
615 f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
616 f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
617 )
618 DebugDatabase.add_optimised(op, op)
619 return op
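
# Illustrative example: a DEPTHWISE_CONV_2D with ifm depth 1 and
# depth_multiplier 8 producing an 8-channel ofm is rewritten to a Conv2DBias;
# the last two axes of the weight tensor are swapped (np.transpose with
# (0, 1, 3, 2)) and the multiplier attributes are dropped.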
620
621
622def reorder_depthwise_weights(op, arch, nng):
623 if op.type.is_depthwise_conv2d_op():
624 weight_tensor = op.inputs[1]
625 weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
626 weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
627 weight_tensor.weight_transpose_depthwise = True
628
629 return op
630
631
632def optimise_strided_conv(op, arch, nng):
633 stride_x, stride_y = op.get_kernel_stride()
634 ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
635
636 if (
637 op.type == Op.Conv2DBias
638 and op.op_index == 0
639 and stride_x == 2
640 and op.ifm_shapes[0].depth <= 4
641 and op.ifm_shapes[0].width % 2 == 0
642 and weight_tensor is not None
643 and weight_tensor.shape[1] >= 2
644 ):
645 ifm_shape = op.ifm_shapes[0]
646 # IFM
647 op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
648
649 # Weights
650 weight_shape = weight_tensor.shape
651 if weight_shape[1] % 2 != 0:
652 weight_shape[1] = weight_shape[1] + 1
653 padded_array = np.zeros(weight_shape)
654 for i in range(weight_shape[0]):
655 padded_array[i] = np.vstack(
656 [
657 weight_tensor.quant_values[i],
658 np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
659 ]
660 )
661 weight_tensor.quant_values = padded_array
662 weight_shape[1] //= 2
663 weight_shape[2] *= 2
664 weight_tensor.quant_values = np.reshape(weight_tensor.quant_values, weight_shape)
665 weight_tensor.set_all_shapes(weight_shape)
666 # If multiple copies of the weights are used, we could avoid
667 # them having the same address by changing the value_id
668 weight_tensor.value_id = uuid.uuid4()
669
670 # Strides
671 stride_x = 1
672 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
673
674 return op
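
# Worked example (illustrative): a first-layer Conv2DBias with stride_w == 2,
# an IFM of Shape4D([1, 128, 128, 3]) and a 3x3 kernel is rewritten to operate
# on Shape4D([1, 128, 64, 6]): the kernel width is first zero-point padded from
# 3 to 4 and then folded into the channel axis, giving a 3x2 kernel over 6
# channels, and the horizontal stride is reduced to 1.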
675
676
677def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
682 if op.type == Op.Conv2DBias:
683 h = op.ifm_shapes[0].height
684 w = op.ifm_shapes[0].width
685 kh, kw, _, _ = op.inputs[1].shape
686 if h == 1 and w == 1 and kh == 1 and kw == 1:
687 # Overwrite this op as a Fully Connected Op
688 op.name += "_fc"
689 op.type = Op.FullyConnected
690 op.attrs = {
691 "weights_format": 0,
692 }
693 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
694 weight_tensor = op.inputs[1]
695 weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
696 weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
697
698 DebugDatabase.add_optimised(op, op)
699 return op
700
701
702def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
703 if op.run_on_npu and op.type.is_relu_op():
704 ifm = op.inputs[0]
705 ofm = op.outputs[0]
706 # Relu with differing IFM and OFM scaling cannot be fused with another primary op
707 # and requires its own to be inserted
708 if not check_quantized_tens_scaling_equal(ifm, ofm):
709 # Override this op with its own primary op (avgpool)
710 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
711 # And fuse the original activation function to it
712 relu_fused_op.activation = create_activation_function(op.type)
713 # Tidy up and assign the ifm and ofm to the new op
714 ifm.consumer_list.remove(op)
715
716 relu_fused_op.add_input_tensor(ifm)
717 relu_fused_op.set_output_tensor(ofm)
718 relu_fused_op.set_ifm_ofm_shapes()
719 op = relu_fused_op
720 return op
721
722
723def fixup_elementwise_with_scalars(op, arch, nng):
724 if op.type.is_binary_elementwise_op():
725 ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
726 if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
727 diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
728 if diff > 0:
729 ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
730 elif diff < 0:
731 ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
732 elif ifm_tensor.shape == [] and ifm_tensor.quant_values is None:
733 # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
734 ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
735 ifm_tensor.storage_shape = ifm_tensor.shape
736 elif ifm2_tensor.shape == [] and ifm2_tensor.quant_values is None:
737 # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
738 ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
739 ifm2_tensor.storage_shape = ifm2_tensor.shape
740 return op
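
# Worked example (illustrative): for an Add with ifm shape [1, 8, 8, 4] and a
# constant ifm2 of shape [4], ifm2 is broadcast up to [1, 1, 1, 4]; a true
# scalar ifm2 (shape []) that is the output of another op instead gets the
# placeholder shape [1, 1, 1, 1].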
741
742
743def convert_softmax(op, arch, nng):
744 if op.type == Op.Softmax and op.run_on_npu:
745 softmax = SoftMax(op)
746 op = softmax.get_graph()
747 return op
748
749
750def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
751 r"""Whenever there is a subgraph with this topology:
752
753 Input X For X = -1 or X > 0
754 | \ / This subgraph can be replaced with either
755 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
756 | /
757 Max
758 """
759
760 if op.type == Op.Maximum:
761 # finds the Mul input(s) to the Max
762 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
763 if len(muls) == 1:
764 mul = muls[0].ops[0]
765 elif len(muls) == 2:
766 # In the case both inputs are Muls, find the one with the same input as the Max
767 mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
768 else:
769 # No Mul inputs
770 return op
771
772 # make sure the Mul doesn't have any other consumers
773 mul_ofm = mul.outputs[0]
774 if len(mul_ofm.consumers()) != 1:
775 return op
776 # make sure the Mul doesn't have a fused activation function
777 if mul.activation:
778 return op
779 ifm, ofm = op.get_ifm_ofm()
780 if ifm is None or ofm is None:
781 return op
782
783 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
784 return op
785 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
786 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
787 return op
788
789 # finds the branched input that goes to both the Max and the Mul
790 shared = set(op.inputs) & set(mul.inputs)
791 if len(shared) == 1:
792 shared_in = shared.pop()
793 # find the constant scalar input to the Mul
794 const_tens = (set(mul.inputs) - {shared_in}).pop()
795 # check that it is a scalar
796 if const_tens.shape != []:
797 return op
798 const = const_tens.ops[0]
799 # check that it is a constant
800 if const.type != Op.Const:
801 return op
802 # Remove the Mul from the shared input's consumers
803 shared_in.consumer_list.remove(mul)
804 else:
805 return op
806
807 val = const.outputs[0].values
808 if val >= 0:
809 new_op = Op.LeakyRelu
810 op.attrs["alpha"] = val
            # to produce bit-exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scaling", to be used as input
            # to the LUT construction
814 alpha_scalar = const_tens.quant_values - const_tens.quantization.zero_point
815 mul_ifm_scale = np.double(ifm.quantization.scale_f32)
816 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
817 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
818 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
819 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
820 elif val == -1:
821 new_op = Op.Abs
822 else:
823 return op
824
825 op.type = new_op
826 op.name = op.name.replace("Maximum", new_op.name)
827 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
828 op.inputs = [shared_in]
829 op.set_ifm_ofm_shapes()
830
831 # Record optimisation in debug database
832 DebugDatabase.add_optimised(op, op)
833
834 return op
835
836
837def convert_hardswish_to_lut(op, arch, nng):
838 if op.type == Op.HardSwish:
839 ifm, ofm = op.get_ifm_ofm()
840 # Generate the LUT
841 ifm_scale = np.double(ifm.quantization.scale_f32)
842 ofm_scale = np.double(ofm.quantization.scale_f32)
843 zp_in = ifm.quantization.zero_point
844 zp_out = ofm.quantization.zero_point
845 ifm_scale_hires = (1 / 128) * ifm_scale
846 relu_multiplier = np.double(3 / 32768)
847 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
848 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
849 # Use 16bit scale
850 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
851 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
852
853 values = []
854 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
855 quantized_min = min(ix)
856 quantized_max = max(ix)
857 for x in ix:
858 input_value = x - zp_in
859 input_value_hires = input_value * 128
860 # Compute the input value on essentially the output scale, not shifted yet
861 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
862 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
863 relu_value = np.int16(input_value_hires)
864 if relu_shift < 31:
865 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
866
867 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
868
869 if relu_shift < 31:
870 relu_value = fp_math.shift_left16(relu_value, 1)
871
872 if relu_shift > 31:
873 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
874
875 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
876 # Now convert that to a 16bit fixedpoint value in [0, 1]
877 relu_value = (relu_value + (1 << 15)) >> 1
878 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
879 shift = 31 - out_shift
880 shift = -shift if shift < 0 else 0
881 # Finally apply the output shift
882 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
883 lut_result = min(quantized_max, max(quantized_min, lut_result))
884 values.append(lut_result)
885 return convert_to_lut(op, values, "hardswish")
886 return op
887
888
889def convert_lrelu_to_mul_max(op, arch):
890 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
891 # (the opposite of convert_mul_max_to_abs_or_lrelu)
892 ifm, ofm = op.get_ifm_ofm()
893 if ifm is None or ofm is None:
894 return op
895
896 # Add multiplication with alpha
897 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
898 mul_alpha.add_input_tensor(ifm)
899 # Create const tensor containing alpha as scalar
900 alpha = op.attrs["alpha"]
901 quantization = ifm.quantization.clone()
902 quantization.min = 0
903 quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
904 quantization.zero_point = 0
905 if np.isinf(1 / np.float32(alpha)):
906 # Handling of alpha near zero
907 quantization.scale_f32 = 1
908 scalar = 0
909 else:
910 quantization.scale_f32 = alpha
911 scalar = alpha
912 alpha_tens = create_const_tensor(
913 op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
914 )
915 alpha_tens.quant_values = np.array([1])
916 mul_alpha.add_input_tensor(alpha_tens)
917 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
918 mul_alpha.set_output_tensor(fm_alpha)
919 mul_alpha.set_ifm_ofm_shapes()
920 DebugDatabase.add_optimised(op, mul_alpha)
921
922 if check_quantized_tens_scaling_equal(ifm, ofm):
923 # No identity multiplication is needed
924 fm_id = ifm
925 else:
926 # Add multiplication with identity
927 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
928 mul_identity.add_input_tensor(ifm)
929 # Create const tensor containing identity as scalar
930 quantization = ifm.quantization.clone()
931 quantization.min = 0
932 quantization.max = quantization.quant_max - quantization.quant_min
933 quantization.scale_f32 = 1
934 quantization.zero_point = 0
935 identity_tens = create_const_tensor(
936 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
937 )
938 mul_identity.add_input_tensor(identity_tens)
939 # Make sure that fm_id is allocated to a different address than fm_alpha
940 fm_id = ofm.clone(op.name + "_id", set_unique=True)
941 mul_identity.set_output_tensor(fm_id)
942 mul_identity.set_ifm_ofm_shapes()
943 DebugDatabase.add_optimised(op, mul_identity)
944
945 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
946 op.type = Op.Maximum
947 op.name = op.name.replace("LeakyRelu", "Maximum")
948 op.inputs = []
949 ifm.consumer_list.remove(op)
950 op.add_input_tensor(fm_alpha)
951 op.add_input_tensor(fm_id)
952 op.set_ifm_ofm_shapes()
953
954 DebugDatabase.add_optimised(op, op)
955 return op
956
957
958def convert_to_lut(op, lut_values, lut_name):
959 # Rewrite the operation by Add with scalar 0 + LUT activation
960 ifm = op.inputs[0]
961 if ifm is None:
962 return op
963 assert ifm.dtype.size_in_bytes() == 1
964 op.type = Op.Add
965 op.name = op.name + "_lut_" + lut_name
966 # Mark as no-op to enable potential fusing optimizations
967 op.attrs["is_nop"] = True
968 # Create an input tensor containing scalar zero
969 quantization = QuantizationParameters(0.0, 255.0)
970 quantization.scale_f32 = ifm.quantization.scale_f32
971 quantization.zero_point = 0
972 tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
973 op.add_input_tensor(tens)
974 op.ifm_shapes.append(Shape4D(tens.shape))
975
976 # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
977 # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
978 # should be the same as the IFM
979 op.forced_output_quantization = ifm.quantization
980 lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
981 op.set_activation_lut(lut_tensor)
982 op.set_ifm_ofm_shapes()
983 return op
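
# Illustrative note: after convert_to_lut the original op has become an
# elementwise Add of the IFM and a scalar zero constant, marked "is_nop", with
# the LUT attached as its activation; the forced output quantization keeps the
# rescaling inside the LUT itself rather than in the Add.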
984
985
986def convert_to_lut8(op, fn, fn_name):
987 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
988 # fn is a function(real) -> real
989 ifm, ofm = op.get_ifm_ofm()
990 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
991 return op
992 # Generate the LUT
993 ifm_scale = np.double(ifm.quantization.scale_f32)
994 ofm_scale = np.double(ofm.quantization.scale_f32)
995 zp_in = ifm.quantization.zero_point
996 zp_out = ofm.quantization.zero_point
997 values = []
998 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
999 quantized_min = min(ix)
1000 quantized_max = max(ix)
1001 for x in ix:
1002 x_real = ifm_scale * (x - zp_in)
1003 y_real = fn(x_real)
1004 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
1005 lut_result = min(quantized_max, max(quantized_min, lut_result))
1006 values.append(lut_result)
1007 return convert_to_lut(op, values, fn_name)
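
# Worked example (illustrative): each LUT entry is
#   clamp(round_away_zero(zp_out + fn(ifm_scale * (x - zp_in)) / ofm_scale))
# so for Tanh on int8 with ifm_scale = ofm_scale = 1/128 and zero points 0,
# the entry for x = 64 is round_away_zero(128 * tanh(0.5)) = 59.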
1008
1009
1010def convert_lrelu_to_lut(op, arch):
1011 ifm, ofm = op.get_ifm_ofm()
1012 # Generate the LUT
1013 alpha = op.attrs["alpha"]
1014 ifm_scale = np.double(ifm.quantization.scale_f32)
1015 ofm_scale = np.double(ofm.quantization.scale_f32)
1016 zp_in = ifm.quantization.zero_point
1017 zp_out = ofm.quantization.zero_point
1018 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
1019 alpha_scalar = 1
1020 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
1021 if "alpha_scaling" in op.attrs:
1022 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
1023 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
1024 values = []
1025 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
1026 quantized_min = min(ix)
1027 quantized_max = max(ix)
1028 for x in ix:
1029 if x < zp_in:
1030 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
1031 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
1032 )
1033 else:
1034 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
1035 lut_result = min(quantized_max, max(quantized_min, lut_result))
1036 values.append(lut_result)
1037 return convert_to_lut(op, values, "lrelu")
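
# Worked example (illustrative): with identical input/output quantization and
# alpha = 0.1, the LUT maps values below the input zero point through the
# quantized alpha multiplier and everything else through the identity scale,
# i.e. roughly lut[x] = zp_out + 0.1 * (x - zp_in) for x < zp_in and
# lut[x] = zp_out + (x - zp_in) otherwise, clamped to the int8/uint8 range.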
1038
1039
1040def convert_lrelu(op, arch, nng):
1041 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
1042 if op.type != Op.LeakyRelu:
1043 return op
1044 ifm, ofm = op.get_ifm_ofm()
1045 if ifm is None or ofm is None:
1046 return op
1047 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
1048 # use LUT for int8/uint8
1049 return convert_lrelu_to_lut(op, arch)
1050 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
1051 # use LeakyRelu unmodified for int16 with equal input/output scaling
1052 return op
1053 return convert_lrelu_to_mul_max(op, arch)
1054
1055
1056def convert_tanh_sigmoid_to_lut(op, arch, nng):
1057 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
1058 if op.type == Op.Sigmoid:
1059 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
1060 elif op.type == Op.Tanh:
1061 return convert_to_lut8(op, math.tanh, "tanh")
1062 return op
1063
1064
1065def remove_reshapes(op, arch):
1066 if op.run_on_npu and op.type == Op.Reshape:
1067 ofm = op.ofm
1068 ifm = op.ifm
1069
1070 # Check if quantization is the same in the input and output for the reshape ops
1071 if not check_quantized_tens_scaling_equal(ifm, ofm):
            # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
            # In order to remove this reshape, either the quantization properties need to be moved to Operator,
            # or the reshape needs to be replaced with a NOP.
1075 return
1076
1077 # Check if Reshape ifm/ofm are network ifm/ofm
1078 ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
1079 ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
1080 ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
        # Check if the ifm is produced, respectively the ofm consumed, by the CPU
1082 ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
1083 ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
1084
1085 # This case should be handled prior to this function
1086 assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))
1087
1088 if ofm_is_sg_ofm or ofm_is_cpu_consumed:
1089 # Bypassed by replacing ifm with ofm
1090 ofm.ops = []
1091 for prev_op in ifm.ops:
1092 prev_op.outputs = [ofm]
1093 ofm.ops.append(prev_op)
1094
1095 # All ifm consumers need to use ofm as input
1096 for ifm_cons in ifm.consumer_list:
1097 for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
1098 if cons_ifm == ifm:
1099 ifm_cons.set_input_tensor(ofm, ifm_idx)
1100 else:
1101 # Bypassed Reshape by replacing ofm with ifm
1102 for cons in ofm.consumer_list:
1103 for ifm_idx, cons_ifm in enumerate(cons.inputs):
1104 if cons_ifm == ofm:
1105 cons.set_input_tensor(ifm, ifm_idx)
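
# Illustrative example: for prev_op -> ifm -> Reshape -> ofm -> consumer, where
# both tensors share quantization, the Reshape is bypassed by wiring the
# consumer directly to ifm (or, when the ofm is a subgraph or CPU output, by
# letting the producer of ifm write ofm instead), so the Reshape no longer
# appears in the graph.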
1106
1107
1108def fuse_activation_function_with_prev(op, arch, nng):
1109 # if op is a no-op: attempts to move the activation function to the preceding op
1110 if not op.attrs.get("is_nop", False) or op.activation is None:
1111 return op
1112 ifm, ofm = op.get_ifm_ofm()
1113 if ifm is None or ofm is None:
1114 return op
1115 # finds the input(s) to the operation
1116 prev_op = ifm.ops[0]
1117 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1118 fuse = (
1119 prev_op.run_on_npu
1120 and prev_op.type.npu_block_type != NpuBlockType.Default
1121 and len(ifm.ops) == 1
1122 and len(prev_op.outputs[0].consumers()) == 1
1123 and prev_op.activation is None
1124 )
1125 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1126 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1127 # LUT currently only works correctly for elementwise ops
1128 fuse = False
1129 if not fuse:
1130 return op
1131 # Move the fused activation function + corresponding info to prev_op
1132 prev_op.activation = op.activation
1133 prev_op.forced_output_quantization = op.forced_output_quantization
1134 if op.activation_lut is not None:
1135 prev_op.set_activation_lut(op.activation_lut)
1136 # Bypass op
1137 prev_op.set_output_tensor(ofm)
1138 DebugDatabase.add_optimised(op, prev_op)
1139 return op
1140
1141
1142def _leading_pad_ok(leading_pad, stride, kernel_size):
1143 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1144 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
1145 max_size = kernel_size // 2
1146 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
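
# Worked example (illustrative): for a 7x7 kernel (max_size == 3) at stride 2,
#   _leading_pad_ok(2, 2, 7)  # -> True  (2 % 2 == 0)
#   _leading_pad_ok(1, 2, 7)  # -> False, hardware padding would make the kernel
#                             #    sample the wrong IFM rows/columns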
1147
1148
1149def replace_pad_by_hw_pad(op: Operation, arch, nng):
1150 """
1151 Tries to completely remove a PAD operator by using hardware padding.
1152 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1153 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1154 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1155 if both operations can be run on the NPU.
1156 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1157 """
1158 if (
1159 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
1160 and op.run_on_npu
1161 and op.attrs["padding"] == Padding.VALID
1162 ):
1163 pad_op = op.ifm.ops[0]
1164 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1165 return op
1166 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1167 return op
1168 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1169 k = op.kernel
1170 k_w, k_h = k.dilated_wh()
1171
1172 # Check if the PAD operator can be replaced by hardware padding
1173 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1174 # Too much padding, it would require hardware padding to actually insert zeros
1175 return op
1176 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1177 return op
1178
1179 if op.type.is_avgpool_op():
1180 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1181 for pad, k_size in (
1182 (left, k_w),
1183 (right, k_w),
1184 (top, k_h),
1185 (bottom, k_h),
1186 ):
1187 if pad not in (0, k_size // 2):
1188 return op
1189 # Average pool is converted to depthwise, because NPU average pool + same padding
1190 # has a special implementation that is different from PAD followed by average pool with
1191 # valid padding.
1192 k_w, k_h = op.kernel.width, op.kernel.height
1193 ifm = op.ifm
1194 # Remember other inputs
1195 other_inputs = op.inputs[1:]
1196 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1197 quantization = QuantizationParameters(0.0, 255.0)
1198 quantization.scale_f32 = 1.0 / (k_w * k_h)
1199 quantization.zero_point = 0
1200 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1201 weights = np.full(shape, 1)
1202
1203 weight_tens = create_const_tensor(
1204 op.name + "_weights",
1205 shape,
1206 op.ifm.dtype,
1207 weights,
1208 np.uint8,
1209 purpose=TensorPurpose.Weights,
1210 quantization=quantization,
1211 )
1212 weight_tens.quant_values = weights
1213 op.type = Op.DepthwiseConv2DBias
1214 op.inputs = []
1215 op.add_input_tensor(ifm)
1216 op.add_input_tensor(weight_tens)
1217 # Add bias tensor, all biases set to 0
1218 op.inputs.append(None)
1219 fixup_bias_tensors(op, arch, nng)
1220 # Add other inputs
1221 op.inputs.extend(other_inputs)
1222 op.rounding_mode = NpuRoundingMode.NATURAL
1223
1224 # Bypass the PAD operator
1225 op.set_input_tensor(pad_op.ifm, 0)
1226 # Adjust the padding attributes of the convolution operator
1227 op.attrs["padding"] = Padding.EXPLICIT
1228 op.attrs["explicit_padding"] = (top, left, bottom, right)
1229 op.set_ifm_ofm_shapes()
1230 return op
1231
1232
1233def convert_pad(op: Operation, arch, nng):
1234 """
1235 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1236 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1237 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1238 """
1239 if op.type != Op.Pad or not op.run_on_npu:
1240 return op
1241 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1242
1243 ifm = op.ifm
1244 assert ifm is not None
1245 ifm_shape = Shape4D(ifm.shape)
1246 ofm = op.ofm
1247 assert ofm is not None
1248 ofm.ops = []
1249 ofm_shape = op.ofm_shapes[0]
1250
1251 # Average pool op that copies IFM to the right place inside the OFM
1252 shp0 = Shape4D(0, 0, 0, 0)
1253 shp_top = shp0.with_height(top)
1254 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1255 avgpool_op.activation = op.activation
1256 quant = ofm.quantization
1257 pad_value = quant.zero_point
1258 # Add operations that fill the borders of the OFM
1259 if top > 0:
1260 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1261 zero_tens = create_const_tensor(
1262 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1263 )
1264 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1265 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1266 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1267 if bottom > 0:
1268 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1269 zero_tens = create_const_tensor(
1270 op.name + "_bottom",
1271 shape.as_list(),
1272 ofm.dtype,
1273 shape.elements() * [pad_value],
1274 np.uint8,
1275 quantization=quant,
1276 )
1277 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1278 create_avg_pool_for_concat(
1279 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1280 )
1281 if left > 0:
1282 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1283 zero_tens = create_const_tensor(
1284 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1285 )
1286 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1287 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1288 if right > 0:
1289 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1290 zero_tens = create_const_tensor(
1291 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1292 )
1293 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1294 create_avg_pool_for_concat(
1295 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1296 )
1297
1298 op.type = Op.ConcatTFLite
1299 return avgpool_op
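
# Illustrative example: a PAD of (top=1, left=2, bottom=0, right=0) becomes one
# "_main" average pool that writes the IFM at offset (1, 2) inside the OFM,
# plus a "_top" and a "_left" average pool that copy constant zero-point
# tensors into the remaining border rows/columns.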
1300
1301
1302def add_attrs_to_resizebilinear(op, arch, nng):
1303 if op.type == Op.ResizeBilinear and op.run_on_npu:
1304 input_tensor = op.inputs[0]
1305 input_shape = op.ifm_shapes[0]
1306 upscaled_height = input_shape.height * 2
1307 upscaled_width = input_shape.width * 2
1308 out_shape = op.ofm_shapes[0]
1309 if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
            # this means the output is supposed to be a 2x upscale,
            # so we need to do SAME padding
1312 op.attrs["padding"] = Padding.SAME
1313 elif (
1314 op.attrs["align_corners"]
1315 and out_shape.height == (upscaled_height - 1)
1316 and out_shape.width == (upscaled_width - 1)
1317 ):
1318 # here we can just run the avg pool without padding and
1319 # produce a (M * 2 - 1, N * 2 - 1) sized output
1320 op.attrs["padding"] = Padding.VALID
1321 else:
1322 return op
1323 input_tensor.resampling_mode = resampling_mode.NEAREST
1324 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
1325 return op
1326
1327
1328def fixup_bias_tensors(op, arch, nng):
1329 if op.type.needs_bias() and op.bias is None:
1330 # Op has no bias, add bias tensor filled with zeros
1331 nr_biases = op.inputs[1].shape[-1]
1332 bias_values = [0] * nr_biases
1333 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
1334 bias_tensor.quant_values = bias_tensor.values
1335 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1336
1337 return op
1338
1339
1340def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1341 if op.type == Op.Mean and op.run_on_npu:
1342 keep_dims = op.attrs.get("keep_dims", False)
1343 inp, axis = op.inputs
1344 shape = inp.shape
1345 dims = len(shape)
1346
1347 # Height and width axes have different index depending on dimensions
1348 if axis.shape == [] or axis.shape[0] == 1: # single axis
1349 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1350 if dims in (2, 3):
1351 if axis == 0:
1352 h, w = shape[axis], 1
1353 else:
1354 h, w = 1, shape[axis]
1355 else:
1356 if axis == 1:
1357 h, w = shape[axis], 1
1358 else:
1359 h, w = 1, shape[axis]
1360 else: # multiple axes
1361 axis = sorted(axis.values)
1362 h, w = [shape[i] for i in axis]
1363
1364 # Set necessary depthwise attributes
1365 op.attrs.update(
1366 {
1367 "padding": Padding.VALID,
1368 "stride_h": 1,
1369 "stride_w": 1,
1370 "strides": (1, 1, 1, 1),
1371 "depth_multiplier": 1,
1372 "channel_multiplier": 1,
1373 "dilation_h_factor": 1,
1374 "dilation_w_factor": 1,
1375 "dilation": (1, 1, 1, 1),
1376 }
1377 )
1378 # Change op type
1379 op.type = Op.DepthwiseConv2DBias
1380 # Set IFM/OFM shapes after changing op type
1381 op.set_ifm_ofm_shapes()
1382
1383 weight_scale, bias = 1, None
1384 ofmq, ifmq = op.ofm.quantization, inp.quantization
1385 # Set rounding mode, scaling and zero point based on which reference implementation to match
1386 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1387 if inp.dtype == DataType.uint8:
1388 # This attribute means a different scaling calculation is used in order to match reference
1389 op.low_precision_scaling = True
1390 weight_scale = h * w
1391 # Set zero points to 0 as they will be adjusted for with bias term
1392 foq = ofmq.clone()
1393 foq.zero_point = 0
1394 fiq = ifmq.clone()
1395 fiq.zero_point = 0
1396 op.forced_input_quantization = fiq
1397 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1398 # If the bias term is outside uint8 range, we need an Add op to apply it.
1399 if bias_term < 0 or bias_term > 255:
1400 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1401 # Bias term has higher bitness (i32) than input/output (u8).
1402 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1403 # the bias can only effectively assume values in the range [-255, 255].
1404 intermediate.dtype = DataType.int16
1405 intermediate.quantization.zero_point = 0
1406 add_op = Operation(Op.Add, op.name + "_bias")
1407 add_op.forced_output_quantization = foq
1408 add_op.add_input_tensor(intermediate)
1409 quant = QuantizationParameters()
1410 quant.zero_point = 0
1411 bias_term_tens = create_const_tensor(
1412 op.name + "_bias",
1413 [1, 1, 1, 1],
1414 DataType.int16,
1415 [bias_term],
1416 np.int16,
1417 quantization=quant,
1418 quant_value_dtype=np.int16,
1419 )
1420 add_op.add_input_tensor(bias_term_tens)
1421 add_op.set_output_tensor(op.ofm)
1422 add_op.set_ifm_ofm_shapes()
1423 add_op.activation = op.activation
1424 op.activation = None
1425 op.set_output_tensor(intermediate)
1426 op.set_ifm_ofm_shapes()
1427 # If not, we can just do it with the OFM zero point.
1428 else:
1429 foq.zero_point = bias_term
1430 op.forced_output_quantization = foq
1431 else:
1432 assert inp.dtype == DataType.int8
1433 # Use a depthwise to calculate the sum,
1434 # followed by a multiplication with 1/N to get the MEAN
1435 weight_scale = 1
1436 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1437 intermediate.dtype = DataType.int16
1438 mul_op = Operation(Op.Mul, op.name + "_mul")
1439 mul_op.add_input_tensor(intermediate)
1440 # Create scalar containing 1/N
1441 quant = QuantizationParameters()
1442 quant.zero_point = 0
1443 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1444 # while rounding mode NATURAL would round this to -1.
1445 # This can only occur if N is even, and can be emulated by
1446 # multiplying with a number that is slightly smaller than 1/N.
1447 # It must be so small that other roundings are not affected;
1448 # the calculated value is based on worst case,
1449 # which is sum 256 * N (the maximum sum that can occur with int8)
1450 n = int(h * w)
1451 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1452 quant.scale_f32 = 1 / (n - eps)
1453 scalar = create_const_tensor(
1454 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1455 )
1456 mul_op.add_input_tensor(scalar)
1457 mul_op.set_output_tensor(op.ofm)
1458 mul_op.set_ifm_ofm_shapes()
1459 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1460 mul_op.activation = op.activation
1461 op.activation = None
1462 op.set_output_tensor(intermediate)
1463 op.set_ifm_ofm_shapes()
1464 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1465 # Here we can just use a simple AvgPool with truncating rounding,
1466 # as we're emulating simple integer division.
1467 op.rounding_mode = NpuRoundingMode.TRUNCATE
1468 op.type = Op.AvgPool
1469 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1470 else:
1471 op.rounding_mode = NpuRoundingMode.NATURAL
1472 weight_scale = 1 / (h * w)
1473 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1474 bias = -ifmq.zero_point * h * w
1475 fiq = ifmq.clone()
1476 fiq.zero_point = 0
1477 op.forced_input_quantization = fiq
1478
1479 # Change dimensions to 4
1480 if dims < 4:
1481 shape = [1] + shape
1482 if dims == 2:
1483 shape += [1]
1484
        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
1486 if h > 64:
1487 shape = [shape[0], 1, h * w, shape[3]]
1488 op.ifm_shapes[0] = Shape4D(shape)
1489 if h > 256 and op.type == Op.AvgPool:
1490 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1491
1492 # If the AvgPool version is used, we don't need to do anything else
1493 if op.type == Op.AvgPool:
1494 return op
1495
1496 # Make unit weight tensor quantization
1497 weight_quant = ifmq.clone()
1498 weight_quant.min = 0
1499 weight_quant.max = 255
1500 weight_quant.scale_f32 = weight_scale
1501 weight_quant.zero_point = 0
1502
1503 # Set weight shape to [H,W,C,B]
1504 weight_shape = shape[1:4] + [shape[0]]
1505 # Add unit weight tensor
1506 op.set_input_tensor(
1507 create_const_tensor(
1508 "weights",
1509 weight_shape,
1510 inp.dtype,
1511 np.ones(weight_shape),
1512 value_dtype=np.uint8,
1513 quantization=weight_quant,
1514 ),
1515 1,
1516 )
1517 op.weights.quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
1518
1519 # Add None bias tensor
1520 op.inputs.append(None)
1521 # Add bias tensor
1522 if bias:
1523 bias_shape = [shape[-1]]
1524 op.set_input_tensor(
1525 create_const_tensor(
1526 "bias",
1527 bias_shape,
1528 inp.dtype,
1529 np.ones(bias_shape) * bias,
1530 value_dtype=np.int32,
1531 quant_value_dtype=np.int32,
1532 quantization=None,
1533 ),
1534 2,
1535 )
1536
1537 return op
1538
1539
1540def supported_operator_check(op, arch, nng):
1541 op.run_on_npu = arch.supported_operators.is_operator_supported(op)
1542 return op
1543
1544
1545def tflite_optimise_graph(nng, arch):
1546 # Pre-processing step
1547 pre_process_list = [
1548 supported_operator_check,
1549 set_ifm_ofm_op_shapes,
1550 ]
1551
1552 for idx, sg in enumerate(nng.subgraphs):
1553 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1554 nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
1555 )
1556
1557 # Handle Concat Ops
1558 for idx, sg in enumerate(nng.subgraphs):
1559 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1560 sg.refresh_after_modification()
1561
1562 # Handle Split Ops
1563 for idx, sg in enumerate(nng.subgraphs):
1564 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1565 nng,
1566 sg,
1567 arch,
1568 [],
1569 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1570 rewrite_unsupported=False,
1571 )
1572
1573 for idx, sg in enumerate(nng.subgraphs):
1574 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1575 nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
1576 )
1577
1578 # Handle sg input output
1579 for idx, sg in enumerate(nng.subgraphs):
1580 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1581 nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
1582 )
1583
1584 # Removal of reshapes
1585 for sg in nng.subgraphs:
1586 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
1587 sg.refresh_after_modification()
1588
1589 # Rewrite of operators
1590 op_rewrite_list = [
1591 set_tensor_equivalence,
1592 convert_mean_to_depthwise_conv_or_avgpool,
1593 convert_depthwise_to_conv,
1594 convert_conv_to_fc,
1595 convert_softmax,
1596 optimise_strided_conv,
1597 convert_hardswish_to_lut,
1598 rewrite_fully_connected_input,
1599 convert_batched_fc_shape,
1600 fixup_conv2d_backprop,
1601 fixup_relus_with_differing_ifm_ofm_scaling,
1602 fixup_elementwise_with_scalars,
1603 reorder_depthwise_weights,
1604 fixup_resizebilinear,
1605 fixup_bias_tensors,
1606 convert_mul_max_to_abs_or_lrelu,
1607 convert_lrelu,
1608 convert_tanh_sigmoid_to_lut,
1609 replace_pad_by_hw_pad,
1610 ]
1611
1612 for idx, sg in enumerate(nng.subgraphs):
1613 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1614 nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
1615 )
1616
1617 for idx, sg in enumerate(nng.subgraphs):
1618 # remove passthrough tensors and attempt further optimizations
1619 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1620 nng,
1621 sg,
1622 arch,
1623 [remove_passthrough_tensor],
1624 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1625 )
1626
    # Removal of SplitSliceRead needs to be done after the optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
1629 for sg in nng.subgraphs:
1630 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1631 sg.refresh_after_modification()
1632
1633 return nng