# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import lut
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


64def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
65 """Creates an average pool for the given concat op/input feature map"""
66 ofm = concat_op.ofm
67 avgpool_op = create_avgpool_nop(name)
68 avgpool_op.inputs = [ifm]
69 avgpool_op.outputs = [ofm]
70
71 avgpool_op.write_offset = write_offset
72 avgpool_op.write_shape = ifm_shape
73 ofm.ops.append(avgpool_op)
74 DebugDatabase.add_optimised(concat_op, avgpool_op)
75 avgpool_op.ifm_shapes.append(ifm_shape)
76 avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
77 avgpool_op.memory_function = Op.ConcatSliceWrite
78 return avgpool_op
79
80
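# Pass-through ops (e.g. Identity) are bypassed by returning the input tensor of the producing op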
81def remove_passthrough_tensor(tens, arch, nng):
82 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
83 assert len(tens.ops[0].inputs) == 1
84 tens = tens.ops[0].inputs[0]
85 return tens
86
87
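# Rewrites concat-type ops into one average pool per input that writes its IFM into the correct offset of the shared OFM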
88def rewrite_concat_ops(op, arch):
89 if not op.run_on_npu or not op.type.is_concat_op():
90 return
91
92 axis_4D = 0
93 ofm = op.ofm
94 ofm.ops = []
95 offset = 0
96
97 unfuse_activation_function(op)
98
99 if op.type == Op.Pack:
100 # Pack is also referred to as Stack
101 axis = int(op.attrs["axis"])
102 if axis < 0: # Convert to positive axis
103 axis = len(op.inputs[0].shape) + 1 + axis
104
105 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
106
107 axis_4D = axis + (4 - len(desired_shape))
108
109 for idx, inp in enumerate(op.inputs):
110 op.ifm_shapes[idx] = Shape4D(desired_shape)
111 op.type = Op.PackReshaped
112
113 inputs, axis = op.get_concat_inputs_axis()
114 for idx, inp in enumerate(inputs):
115 if op.type != Op.PackReshaped:
116 op.ifm_shapes[idx] = Shape4D(inp.shape)
117 if axis >= 0:
118 axis_4D = axis + (4 - len(inp.shape))
119 else:
120 axis_4D = axis
121 write_offset = [0, 0, 0, 0]
122 write_offset[axis_4D] = offset
123 concat_end = offset + op.ifm_shapes[idx][axis_4D]
124 create_avg_pool_for_concat(
125 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
126 )
127 offset = concat_end
128 assert ofm.shape[axis] == offset
129
130 return op
131
132
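# Rewrites split-type producers (except Unpack) into a SplitSliceRead op with the appropriate read offset and read shape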
133def rewrite_split_ops(tens, arch, nng):
134
135 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
136 split_op = tens.ops[0]
137
138 # Not supported so leave it and run on CPU
139 if not split_op.run_on_npu:
140 return tens
141
142 inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
143
144 tens.ops = []
145 new_op = Operation(Op.SplitSliceRead, split_op.name)
146 new_op.inputs = [inp]
147 ofm_shape_idx = 0
148 read_shape = offset_end
149
150 # For Split the offset cannot be extracted from the tensor so it has to
151 # be calculated from the index of the output tensor
152 if axis is not None:
153 # Get the start and end of the split
154 offset_start = [0] * 4
155 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
156 for idx, out in enumerate(outputs):
157 if axis_4D_list is not None:
158 axis_4D = axis_4D_list[idx]
159 else:
160 split_op.ofm_shapes[idx] = Shape4D(out.shape)
161 if axis >= 0:
162 axis_4D = axis + (4 - len(out.shape))
163 else:
164 axis_4D = axis
165
166 if out == tens:
167 ofm_shape_idx = idx
168 read_shape = split_op.ofm_shapes[idx]
169 break
170
171 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
172
173 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
174 new_op.read_shapes[0] = read_shape
175 new_op.run_on_npu = True
176 new_op.set_output_tensor(tens)
177 new_op.ifm_shapes.append(Shape4D(inp.shape))
178 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
179 DebugDatabase.add_optimised(split_op, new_op)
180
181 return tens
182
183
def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in (Op.Reshape, Op.Squeeze)
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)
211
212
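# Inserts an average-pool copy op after the tensor and redirects the tensor's consumers to the copy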
213def insert_copy_op_after_tens(tens):
214 tens_cons_list_copy = tens.consumer_list.copy()
215
    # Create an avg_pool nop op with ifm as input
217 copy_tens = tens.clone()
218 copy_op = create_avgpool_nop(tens.name + "_avgpool")
219 copy_op.add_input_tensor(tens)
220 copy_op.set_output_tensor(copy_tens)
221 copy_op.set_ifm_ofm_shapes()
222 copy_op.run_on_npu = True
223
224 # Set copy_ifm consumers
225 for tens_cons in tens_cons_list_copy:
226 if tens_cons is not None:
227 for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
228 if cons_inp == tens:
229 tens_cons.set_input_tensor(copy_tens, ifm_idx)
230
231 DebugDatabase.add_optimised(tens.ops[0], copy_op)
232
233
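# Calculates the explicit (top, left, bottom, right) padding and the corresponding skirt for a convolution/pooling kernel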
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
235 k_w, k_h = kernel.dilated_wh()
236 s_x, s_y = kernel.stride
237 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
238 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
239 if padding_type == Padding.SAME:
240 left_pad = (xpad + 0) // 2
241 right_pad = (xpad + 1) // 2
242 top_pad = (ypad + 0) // 2
243 bottom_pad = (ypad + 1) // 2
244 elif padding_type == Padding.VALID:
245 left_pad = 0
246 right_pad = 0
247 top_pad = 0
248 bottom_pad = 0
249 elif padding_type == Padding.EXPLICIT:
250 # Padding is specified in a PAD operator which has been bypassed.
251 top, left, bottom, right = explicit_padding
252 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
253 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
254 else:
        raise UnsupportedFeatureError("Unknown padding")
256 padding = (top_pad, left_pad, bottom_pad, right_pad)
257 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
258 return padding, skirt
259
260
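# Calculates padding and skirt when the IFM is upscaled, e.g. for Conv2DBackpropInputSwitchedBias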
261def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
262 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
263 if padding_type == Padding.SAME:
264 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
265 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
266 right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
267 bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
268 left_pad = max(kernel_width - 1 - right_pad, 0)
269 top_pad = max(kernel_height - 1 - bottom_pad, 0)
270 elif padding_type == Padding.VALID:
271 right_pad = max(kernel_width - 2, 0)
272 bottom_pad = max(kernel_height - 2, 0)
273 left_pad = kernel_width - 1
274 top_pad = kernel_height - 1
275 else:
        raise UnsupportedFeatureError("Unknown padding")
277 padding = (top_pad, left_pad, bottom_pad, right_pad)
278 skirt = padding
279 return padding, skirt
280
281
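# Rewrites Conv2DBackpropInput to the switched-bias variant: flips the inputs, sets TRANSPOSE resampling and unit strides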
282def fixup_conv2d_backprop(op, arch, nng):
283 if op.type == Op.Conv2DBackpropInput:
284 # flip the inputs
285 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
286 op.type = Op.Conv2DBackpropInputSwitchedBias
287 op.ifm.resampling_mode = resampling_mode.TRANSPOSE
288
289 # Update strides
290 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
291
292 return op
293
294
295# Convert the op to an elementwise add
296def convert_resizebilinear_1x1_to_add(op):
297 op.type = Op.Add
298 op.name = op.name + "_add"
299 op.attrs["resizebilinear"] = True
300 # Create an input tensor filled with zeros
301 shape = op.ofm_shapes[0].as_list()
302 tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
305 tens.quantization.scale_f32 = 1.0
306 tens.quantization.zero_point = 0
307 tens.consumer_list = [op]
308 tens_op = op.inputs[1].ops[0]
309 tens_op.set_output_tensor(tens)
310 # Set the add inputs
311 op.inputs[1] = op.inputs[0]
312 op.inputs[0] = tens
313 op.set_ifm_ofm_shapes()
314
315 return op
316
317
318# Convert ResizeBilinear to a number of 2x2 pool ops
319def convert_resizebilinear_to_2x2_pool(op):
320 count = 0
321 pre_op = op
322 outputs = op.outputs
323
324 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
325 if op.attrs["align_corners"]:
326 shape_modifier = 1
327 op.attrs["padding"] = Padding.VALID
328 else:
329 shape_modifier = 0
330 op.attrs["padding"] = Padding.SAME
331 op.inputs[0].resampling_mode = resampling_mode.NEAREST
332
333 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
334 out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
335 if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
336 return op
337
338 while (upscaled_shape < out_shape).all():
339 if count == 0:
340 scaled_op = pre_op
341 else:
342 scaled_op = op.clone("_{}".format(count))
343 scaled_op.inputs[0] = pre_op.outputs[0]
344
345 upscaled_shape = upscaled_shape * 2 - shape_modifier
346
347 if (upscaled_shape == out_shape).all():
348 scaled_op.outputs = outputs
349 scaled_op.outputs[0].ops = [scaled_op]
350 else:
351 shape = op.ofm_shapes[0].as_list()
352 shape[1:3] = upscaled_shape
353 out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
354 out_tens.quantization = op.outputs[0].quantization.clone()
355 out_tens.quantization.quant_min = np.iinfo(np.int16).min
356 out_tens.quantization.quant_max = np.iinfo(np.int16).max
357 scaled_op.set_output_tensor(out_tens)
358 pre_op = scaled_op
359 count += 1
360
361 # Setup the scale value
362 if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
363 scaled_op.rescale = 128
364 elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
365 scaled_op.rescale = 1 / 128
366 else:
367 scaled_op.rescale = None
368 scaled_op.set_ifm_ofm_shapes()
369
370 return op
371
372
373def fixup_resizebilinear(op, arch, nng):
374 if op.type == Op.ResizeBilinear and op.run_on_npu:
375 if op.ifm_shapes[0] == op.ofm_shapes[0]:
376 # Bypass nop resizebilinear
377 op.inputs = op.inputs[:1]
378 op.type = Op.Identity
379 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
380 convert_resizebilinear_1x1_to_add(op)
381 else:
382 convert_resizebilinear_to_2x2_pool(op)
383
384 return op
385
386
387def convert_nop_split_to_identity(op, arch, nng):
388 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
389 # the list comprehension should return a list with a single tensor
390 # if it shouldn't, remove_passthrough_tensor will fail appropriately
391 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
392 op.type = Op.Identity
393 return op
394
395
396def rewrite_fully_connected_input(op, arch, nng):
397 if op.type == Op.FullyConnected:
398 n_in_elems = op.weights.shape[-2]
399 elms = op.ifm.elements()
400 batch_size = elms // n_in_elems
401 assert batch_size * n_in_elems == elms
402
403 op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
404 return op
405
406
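# Converts batched FullyConnected IFM/OFM shapes to a 4D form the NPU can handle,
# e.g. a batch of 8 becomes a 1x2x4 spatial area (see batching_split below)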
407def convert_batched_fc_shape(op, arch, nng):
408 if op.type == Op.FullyConnected:
409 # Check if the first dimension indicates batching
410 if op.ifm_shapes[0].batch > 1:
411 batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
412 n = op.ifm_shapes[0].batch
413 h, w = batching_split.get(n, (1, n))
414 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
415
416 # Reshape Weights to be 4D. IO becomes HWIO
417 weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

421 n = op.ofm_shapes[0].batch
422 h, w = batching_split.get(n, (1, n))
423 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
424 return op
425
426
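# If an NPU concat op has a fused activation function, it is moved out into a separate op behind an intermediate tensor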
427def unfuse_activation_function(op):
428 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
429 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
430 op.activation = None
431 out_tens = op.outputs[0]
432 intermediate_tens = out_tens.clone("_act_intermediate")
433 act_op.set_output_tensor(out_tens)
434 act_op.add_input_tensor(intermediate_tens)
435 op.set_output_tensor(intermediate_tens)
436 act_op.set_ifm_ofm_shapes()
437
438
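# Handles the new_axis_mask/shrink_axis_mask attributes of StridedSlice by adjusting the 4D output shapes
# and recording the split axis in the "split_axis_4D" attribute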
439def rewrite_stridedslice_output(op, arch, nng):
440 if not op.run_on_npu or op.type != Op.StridedSlice:
441 return op
442
443 new_axis_mask = op.attrs["new_axis_mask"]
444 shrink_axis_mask = op.attrs["shrink_axis_mask"]
445
446 if shrink_axis_mask == 0 and new_axis_mask == 0:
447 return op
448
449 axis_4D = [0] * len(op.outputs)
450 for idx, out_tens in enumerate(op.outputs):
451 output_shape = list(out_tens.shape)
452
453 if shrink_axis_mask != 0:
454 n = 0
455 axis = 0
456 while shrink_axis_mask:
457 prev_mask = shrink_axis_mask
458 n += 1
459 shrink_axis_mask &= shrink_axis_mask - 1
460 axis = int(math.log2(prev_mask - shrink_axis_mask))
461 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
462
463 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
464 op.attrs["shrink_axis_mask"] = 0
465 if axis >= 0:
466 axis_4D[idx] = axis + (4 - len(output_shape))
467 else:
468 axis_4D[idx] = axis
469 op.ofm_shapes[idx] = Shape4D(output_shape)
470
471 elif new_axis_mask != 0:
472 n = 0
473 axis = 0
474 while new_axis_mask:
475 prev_mask = new_axis_mask
476 n += 1
477 new_axis_mask &= new_axis_mask - 1
478 axis = int(math.log2(prev_mask - new_axis_mask))
479 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
480 new_axis_mask >>= 1
481
482 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
483 op.attrs["new_axis_mask"] = 0
484 if axis >= 0:
485 axis_4D[idx] = axis + (4 - len(output_shape))
486 else:
487 axis_4D[idx] = axis
488 op.ofm_shapes[idx] = Shape4D(output_shape)
489
490 op.attrs["split_axis_4D"] = axis_4D
491 return op
492
493
494def rewrite_unpack_output(op, arch, nng):
495 tens = op.outputs[0]
496 if op.run_on_npu and op.type == Op.Unpack:
497 # Unpack is also referred to as Unstack
498 axis = int(op.attrs["axis"])
499 if axis < 0: # Convert to positive axis
500 axis = len(op.inputs[0].shape) + 1 + axis
501 op.type = Op.UnpackReshaped
502 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
503
504 axis_4D = axis + (4 - len(desired_output_shape))
505 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
506
507 for idx, out_tens in enumerate(op.outputs):
508 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
509 return op
510
511
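# Calculates padding and skirt for ops that carry a "padding" attribute and stores them in "explicit_padding" and "skirt"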
512def add_padding_fields(op, arch, nng):
513 if op.run_on_npu:
514 if "padding" in op.attrs:
515 input_shape = op.ifm_shapes[0]
516 output_shape = op.ofm_shapes[0]
517 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
518 kernel_size = op.inputs[1].shape[:2]
519 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
520 kernel_size = op.attrs["ksize"][1:3]
521 else:
522 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
523
524 if op.type == Op.Conv2DBackpropInputSwitchedBias:
525 upscaling_factor = output_shape.height // input_shape.height
526 padding, skirt = calc_upscaled_padding_and_skirt(
527 op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
528 )
529 else:
530 padding, skirt = calc_padding_and_skirt(
531 op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
532 )
533
534 op.attrs["explicit_padding"] = padding
535 op.attrs["skirt"] = skirt
536
537 return op
538
539
def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op
548
549
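# Converts an initial stride-2 Conv2D with a small IFM depth into a stride-1 convolution on a reshaped
# (half width, double depth) IFM, with the weights padded and reshaped to match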
550def optimise_strided_conv(op, arch, nng):
551 stride_x, stride_y = op.get_kernel_stride()
552 ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
553
554 if (
555 op.type == Op.Conv2DBias
556 and op.op_index == 0
557 and stride_x == 2
558 and op.ifm_shapes[0].depth <= 4
559 and op.ifm_shapes[0].width % 2 == 0
560 and weight_tensor is not None
561 and weight_tensor.shape[1] >= 2
562 ):
563 ifm_shape = op.ifm_shapes[0]
564 # IFM
565 op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
566
567 # Weights
568 weight_shape = weight_tensor.shape
569 if weight_shape[1] % 2 != 0:
570 weight_shape[1] = weight_shape[1] + 1
571 padded_array = np.zeros(weight_shape)
572 for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
584 # If multiple copies of the weights are used, we could avoid
585 # them having the same address by changing the value_id
586 weight_tensor.value_id = uuid.uuid4()
587
588 # Strides
589 stride_x = 1
590 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
591
592 return op
593
594
595def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
600 if op.type == Op.Conv2DBias:
601 h = op.ifm_shapes[0].height
602 w = op.ifm_shapes[0].width
603 kh, kw, _, _ = op.inputs[1].shape
604 if h == 1 and w == 1 and kh == 1 and kw == 1:
605 # Overwrite this op as a Fully Connected Op
606 op.name += "_fc"
607 op.type = Op.FullyConnected
608 op.attrs = {
609 "weights_format": 0,
610 }
611 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
612 weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

616 DebugDatabase.add_optimised(op, op)
617 return op
618
619
620def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
621 if op.run_on_npu and op.type.is_relu_op():
622 ifm = op.inputs[0]
623 ofm = op.outputs[0]
624 # Relu with differing IFM and OFM scaling cannot be fused with another primary op
625 # and requires its own to be inserted
626 if not check_quantized_tens_scaling_equal(ifm, ofm):
627 # Override this op with its own primary op (avgpool)
628 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
629 # And fuse the original activation function to it
630 relu_fused_op.activation = create_activation_function(op.type)
631 # Tidy up and assign the ifm and ofm to the new op
632 ifm.consumer_list.remove(op)
633
634 relu_fused_op.add_input_tensor(ifm)
635 relu_fused_op.set_output_tensor(ofm)
636 relu_fused_op.set_ifm_ofm_shapes()
637 op = relu_fused_op
638 return op
639
640
641def fixup_elementwise_with_scalars(op, arch, nng):
642 if op.type.is_binary_elementwise_op():
643 ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
644 if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
645 diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
646 if diff > 0:
647 ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
648 elif diff < 0:
649 ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
        elif ifm_tensor.shape == [] and ifm_tensor.values is None:
            # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
            ifm_tensor.storage_shape = ifm_tensor.shape
        elif ifm2_tensor.shape == [] and ifm2_tensor.values is None:
            # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
            ifm2_tensor.storage_shape = ifm2_tensor.shape
658 return op
659
660
661def convert_softmax(op, arch, nng):
662 if op.type == Op.Softmax and op.run_on_npu:
663 softmax = SoftMax(op)
664 op = softmax.get_graph()
665 return op
666
667
668def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
669 r"""Whenever there is a subgraph with this topology:
670
671 Input X For X = -1 or X > 0
672 | \ / This subgraph can be replaced with either
673 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
674 | /
675 Max
676 """
677
678 if op.type == Op.Maximum:
679 # finds the Mul input(s) to the Max
680 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
681 if len(muls) == 1:
682 mul = muls[0].ops[0]
683 elif len(muls) == 2:
684 # In the case both inputs are Muls, find the one with the same input as the Max
685 mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
686 else:
687 # No Mul inputs
688 return op
689
690 # make sure the Mul doesn't have any other consumers
691 mul_ofm = mul.outputs[0]
692 if len(mul_ofm.consumers()) != 1:
693 return op
694 # make sure the Mul doesn't have a fused activation function
695 if mul.activation:
696 return op
697 ifm, ofm = op.get_ifm_ofm()
698 if ifm is None or ofm is None:
699 return op
700
701 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
702 return op
703 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
704 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
705 return op
706
707 # finds the branched input that goes to both the Max and the Mul
708 shared = set(op.inputs) & set(mul.inputs)
709 if len(shared) == 1:
710 shared_in = shared.pop()
711 # find the constant scalar input to the Mul
712 const_tens = (set(mul.inputs) - {shared_in}).pop()
713 # check that it is a scalar
714 if const_tens.shape != []:
715 return op
716 const = const_tens.ops[0]
717 # check that it is a constant
718 if const.type != Op.Const:
719 return op
720 # Remove the Mul from the shared input's consumers
721 shared_in.consumer_list.remove(mul)
722 else:
723 return op
724
725 val = const.outputs[0].values
726 if val >= 0:
727 new_op = Op.LeakyRelu
728 op.attrs["alpha"] = val
729 # to produce bit exact results, the alpha is not enough;
730 # save additional scaling info in attr "alpha_scale", to be used as input
731 # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
734 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
735 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
736 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
737 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
738 elif val == -1:
739 new_op = Op.Abs
740 else:
741 return op
742
743 op.type = new_op
744 op.name = op.name.replace("Maximum", new_op.name)
745 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
746 op.inputs = [shared_in]
747 op.set_ifm_ofm_shapes()
748
749 # Record optimisation in debug database
750 DebugDatabase.add_optimised(op, op)
751
752 return op
753
754
755def convert_hardswish_to_lut(op, arch, nng):
756 if op.type == Op.HardSwish:
757 ifm, ofm = op.get_ifm_ofm()
758 # Generate the LUT
759 ifm_scale = np.double(ifm.quantization.scale_f32)
760 ofm_scale = np.double(ofm.quantization.scale_f32)
761 zp_in = ifm.quantization.zero_point
762 zp_out = ofm.quantization.zero_point
763 ifm_scale_hires = (1 / 128) * ifm_scale
764 relu_multiplier = np.double(3 / 32768)
765 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
766 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
767 # Use 16bit scale
768 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
769 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
770
771 values = []
772 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
773 quantized_min = min(ix)
774 quantized_max = max(ix)
775 for x in ix:
776 input_value = x - zp_in
777 input_value_hires = input_value * 128
778 # Compute the input value on essentially the output scale, not shifted yet
779 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
780 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
781 relu_value = np.int16(input_value_hires)
782 if relu_shift < 31:
783 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
784
785 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
786
787 if relu_shift < 31:
788 relu_value = fp_math.shift_left16(relu_value, 1)
789
790 if relu_shift > 31:
791 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
792
793 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
794 # Now convert that to a 16bit fixedpoint value in [0, 1]
795 relu_value = (relu_value + (1 << 15)) >> 1
796 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
797 shift = 31 - out_shift
798 shift = -shift if shift < 0 else 0
799 # Finally apply the output shift
800 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
801 lut_result = min(quantized_max, max(quantized_min, lut_result))
802 values.append(lut_result)
803 return convert_to_lut(op, values, "hardswish")
804 return op
805
806
807def convert_lrelu_to_mul_max(op, arch):
808 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
809 # (the opposite of convert_mul_max_to_abs_or_lrelu)
810 ifm, ofm = op.get_ifm_ofm()
811 if ifm is None or ofm is None:
812 return op
813
814 # Add multiplication with alpha
815 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
816 mul_alpha.add_input_tensor(ifm)
817 # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
835 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
836 mul_alpha.set_output_tensor(fm_alpha)
837 mul_alpha.set_ifm_ofm_shapes()
838 DebugDatabase.add_optimised(op, mul_alpha)
839
840 if check_quantized_tens_scaling_equal(ifm, ofm):
841 # No identity multiplication is needed
842 fm_id = ifm
843 else:
844 # Add multiplication with identity
845 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
846 mul_identity.add_input_tensor(ifm)
847 # Create const tensor containing identity as scalar
848 quantization = ifm.quantization.clone()
849 quantization.min = 0
850 quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
853 identity_tens = create_const_tensor(
854 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
855 )
856 mul_identity.add_input_tensor(identity_tens)
857 # Make sure that fm_id is allocated to a different address than fm_alpha
858 fm_id = ofm.clone(op.name + "_id", set_unique=True)
859 mul_identity.set_output_tensor(fm_id)
860 mul_identity.set_ifm_ofm_shapes()
861 DebugDatabase.add_optimised(op, mul_identity)
862
863 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
864 op.type = Op.Maximum
865 op.name = op.name.replace("LeakyRelu", "Maximum")
866 op.inputs = []
867 ifm.consumer_list.remove(op)
868 op.add_input_tensor(fm_alpha)
869 op.add_input_tensor(fm_id)
870 op.set_ifm_ofm_shapes()
871
872 DebugDatabase.add_optimised(op, op)
873 return op
874
875
876def convert_to_lut(op, lut_values, lut_name):
877 # Rewrite the operation by Add with scalar 0 + LUT activation
878 ifm = op.inputs[0]
879 if ifm is None:
880 return op
881 assert ifm.dtype.size_in_bytes() == 1
882 op.type = Op.Add
883 op.name = op.name + "_lut_" + lut_name
884 # Mark as no-op to enable potential fusing optimizations
885 op.attrs["is_nop"] = True
886 # Create an input tensor containing scalar zero
887 quantization = QuantizationParameters(0.0, 255.0)
888 quantization.scale_f32 = ifm.quantization.scale_f32
889 quantization.zero_point = 0
890 tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
891 op.add_input_tensor(tens)
892 op.ifm_shapes.append(Shape4D(tens.shape))
893
894 # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
895 # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
896 # should be the same as the IFM
897 op.forced_output_quantization = ifm.quantization
898 lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
899 op.set_activation_lut(lut_tensor)
900 op.set_ifm_ofm_shapes()
901 return op
902
903
904def convert_to_lut8(op, fn, fn_name):
905 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
906 # fn is a function(real) -> real
907 ifm, ofm = op.get_ifm_ofm()
908 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
909 return op
910 # Generate the LUT
911 ifm_scale = np.double(ifm.quantization.scale_f32)
912 ofm_scale = np.double(ofm.quantization.scale_f32)
913 zp_in = ifm.quantization.zero_point
914 zp_out = ofm.quantization.zero_point
915 values = []
916 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
917 quantized_min = min(ix)
918 quantized_max = max(ix)
919 for x in ix:
920 x_real = ifm_scale * (x - zp_in)
921 y_real = fn(x_real)
922 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
923 lut_result = min(quantized_max, max(quantized_min, lut_result))
924 values.append(lut_result)
925 return convert_to_lut(op, values, fn_name)
926
927
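# Converts LeakyRelu to an 8-bit LUT, using separate quantised scales for the negative and positive input ranges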
928def convert_lrelu_to_lut(op, arch):
929 ifm, ofm = op.get_ifm_ofm()
930 # Generate the LUT
931 alpha = op.attrs["alpha"]
932 ifm_scale = np.double(ifm.quantization.scale_f32)
933 ofm_scale = np.double(ofm.quantization.scale_f32)
934 zp_in = ifm.quantization.zero_point
935 zp_out = ofm.quantization.zero_point
936 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
937 alpha_scalar = 1
938 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
939 if "alpha_scaling" in op.attrs:
940 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
941 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
942 values = []
943 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
944 quantized_min = min(ix)
945 quantized_max = max(ix)
946 for x in ix:
947 if x < zp_in:
948 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
949 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
950 )
951 else:
952 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
953 lut_result = min(quantized_max, max(quantized_min, lut_result))
954 values.append(lut_result)
955 return convert_to_lut(op, values, "lrelu")
956
957
958def convert_lrelu(op, arch, nng):
959 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
960 if op.type != Op.LeakyRelu:
961 return op
962 ifm, ofm = op.get_ifm_ofm()
963 if ifm is None or ofm is None:
964 return op
965 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
966 # use LUT for int8/uint8
967 return convert_lrelu_to_lut(op, arch)
968 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
969 # use LeakyRelu unmodified for int16 with equal input/output scaling
970 return op
971 return convert_lrelu_to_mul_max(op, arch)
972
973
974def convert_tanh_sigmoid_to_lut(op, arch, nng):
975 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
976 if op.type == Op.Sigmoid:
977 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
978 elif op.type == Op.Tanh:
979 return convert_to_lut8(op, math.tanh, "tanh")
980 return op
981
982
def remove_reshape_and_squeeze_ops(op, arch):
    if op.run_on_npu and op.type in (Op.Reshape, Op.Squeeze):
        ofm = op.ofm
        ifm = op.ifm

        # Check if quantization is the same in the input and output for the reshape ops
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
            # In order to remove this reshape either quantization properties need to be moved to Operator,
            # or the reshape needs to be replaced with a NOP.
            return

        bypass_reshape_and_squeeze_ops(op)

997
998def fuse_activation_function_with_prev(op, arch, nng):
999 # if op is a no-op: attempts to move the activation function to the preceding op
1000 if not op.attrs.get("is_nop", False) or op.activation is None:
1001 return op
1002 ifm, ofm = op.get_ifm_ofm()
1003 if ifm is None or ofm is None:
1004 return op
1005 # finds the input(s) to the operation
1006 prev_op = ifm.ops[0]
1007 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
1008 fuse = (
1009 prev_op.run_on_npu
1010 and prev_op.type.npu_block_type != NpuBlockType.Default
1011 and len(ifm.ops) == 1
1012 and len(prev_op.outputs[0].consumers()) == 1
1013 and prev_op.activation is None
1014 )
1015 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
1016 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
1017 # LUT currently only works correctly for elementwise ops
1018 fuse = False
1019 if not fuse:
1020 return op
1021 # Move the fused activation function + corresponding info to prev_op
1022 prev_op.activation = op.activation
1023 prev_op.forced_output_quantization = op.forced_output_quantization
1024 if op.activation_lut is not None:
1025 prev_op.set_activation_lut(op.activation_lut)
1026 # Bypass op
1027 prev_op.set_output_tensor(ofm)
1028 DebugDatabase.add_optimised(op, prev_op)
1029 return op
1030
1031
1032def _leading_pad_ok(leading_pad, stride, kernel_size):
1033 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1034 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
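    # E.g. kernel_size 7 and stride 2 gives max_size 3: leading padding of 0, 2 or 3 is ok, but 1 is not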
1035 max_size = kernel_size // 2
1036 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1037
1038
1039def replace_pad_by_hw_pad(op: Operation, arch, nng):
1040 """
1041 Tries to completely remove a PAD operator by using hardware padding.
1042 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1043 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1044 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1045 if both operations can be run on the NPU.
1046 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1047 """
1048 if (
1049 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
1050 and op.run_on_npu
1051 and op.attrs["padding"] == Padding.VALID
1052 ):
1053 pad_op = op.ifm.ops[0]
1054 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1055 return op
1056 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1057 return op
1058 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1059 k = op.kernel
1060 k_w, k_h = k.dilated_wh()
1061
1062 # Check if the PAD operator can be replaced by hardware padding
1063 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1064 # Too much padding, it would require hardware padding to actually insert zeros
1065 return op
1066 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1067 return op
1068
1069 if op.type.is_avgpool_op():
1070 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1071 for pad, k_size in (
1072 (left, k_w),
1073 (right, k_w),
1074 (top, k_h),
1075 (bottom, k_h),
1076 ):
1077 if pad not in (0, k_size // 2):
1078 return op
1079 # Average pool is converted to depthwise, because NPU average pool + same padding
1080 # has a special implementation that is different from PAD followed by average pool with
1081 # valid padding.
1082 k_w, k_h = op.kernel.width, op.kernel.height
1083 ifm = op.ifm
1084 # Remember other inputs
1085 other_inputs = op.inputs[1:]
1086 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1087 quantization = QuantizationParameters(0.0, 255.0)
1088 quantization.scale_f32 = 1.0 / (k_w * k_h)
1089 quantization.zero_point = 0
1090 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1091 weights = np.full(shape, 1)
1092
1093 weight_tens = create_const_tensor(
1094 op.name + "_weights",
1095 shape,
1096 op.ifm.dtype,
1097 weights,
1098 np.uint8,
1099 purpose=TensorPurpose.Weights,
1100 quantization=quantization,
1101 )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
1104 op.inputs = []
1105 op.add_input_tensor(ifm)
1106 op.add_input_tensor(weight_tens)
1107 # Add bias tensor, all biases set to 0
1108 op.inputs.append(None)
1109 fixup_bias_tensors(op, arch, nng)
1110 # Add other inputs
1111 op.inputs.extend(other_inputs)
1112 op.rounding_mode = NpuRoundingMode.NATURAL
1113
1114 # Bypass the PAD operator
1115 op.set_input_tensor(pad_op.ifm, 0)
1116 # Adjust the padding attributes of the convolution operator
1117 op.attrs["padding"] = Padding.EXPLICIT
1118 op.attrs["explicit_padding"] = (top, left, bottom, right)
1119 op.set_ifm_ofm_shapes()
1120 return op
1121
1122
1123def convert_pad(op: Operation, arch, nng):
1124 """
1125 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1126 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1127 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1128 """
1129 if op.type != Op.Pad or not op.run_on_npu:
1130 return op
1131 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1132
1133 ifm = op.ifm
1134 assert ifm is not None
1135 ifm_shape = Shape4D(ifm.shape)
1136 ofm = op.ofm
1137 assert ofm is not None
1138 ofm.ops = []
1139 ofm_shape = op.ofm_shapes[0]
1140
1141 # Average pool op that copies IFM to the right place inside the OFM
1142 shp0 = Shape4D(0, 0, 0, 0)
1143 shp_top = shp0.with_height(top)
1144 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1145 avgpool_op.activation = op.activation
1146 quant = ofm.quantization
1147 pad_value = quant.zero_point
1148 # Add operations that fill the borders of the OFM
1149 if top > 0:
1150 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1151 zero_tens = create_const_tensor(
1152 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1153 )
1154 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1155 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1156 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1157 if bottom > 0:
1158 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1159 zero_tens = create_const_tensor(
1160 op.name + "_bottom",
1161 shape.as_list(),
1162 ofm.dtype,
1163 shape.elements() * [pad_value],
1164 np.uint8,
1165 quantization=quant,
1166 )
1167 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1168 create_avg_pool_for_concat(
1169 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1170 )
1171 if left > 0:
1172 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1173 zero_tens = create_const_tensor(
1174 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1175 )
1176 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1177 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1178 if right > 0:
1179 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1180 zero_tens = create_const_tensor(
1181 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1182 )
1183 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1184 create_avg_pool_for_concat(
1185 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1186 )
1187
1188 op.type = Op.ConcatTFLite
1189 return avgpool_op
1190
1191
1192def add_attrs_to_resizebilinear(op, arch, nng):
1193 if op.type == Op.ResizeBilinear and op.run_on_npu:
1194 input_tensor = op.inputs[0]
1195 input_shape = op.ifm_shapes[0]
1196 upscaled_height = input_shape.height * 2
1197 upscaled_width = input_shape.width * 2
1198 out_shape = op.ofm_shapes[0]
1199 if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
1200 # this means the output is supposed to be a x2 upscale,
1201 # so we need to do SAME padding
1202 op.attrs["padding"] = Padding.SAME
1203 elif (
1204 op.attrs["align_corners"]
1205 and out_shape.height == (upscaled_height - 1)
1206 and out_shape.width == (upscaled_width - 1)
1207 ):
1208 # here we can just run the avg pool without padding and
1209 # produce a (M * 2 - 1, N * 2 - 1) sized output
1210 op.attrs["padding"] = Padding.VALID
1211 else:
1212 return op
1213 input_tensor.resampling_mode = resampling_mode.NEAREST
1214 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
1215 return op
1216
1217
1218def fixup_bias_tensors(op, arch, nng):
1219 if op.type.needs_bias() and op.bias is None:
1220 # Op has no bias, add bias tensor filled with zeros
1221 nr_biases = op.inputs[1].shape[-1]
1222 bias_values = [0] * nr_biases
1223 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1225
1226 return op
1227
1228
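# Converts a Mean op into a DepthwiseConv2DBias (plus an Add or Mul where needed to match the reference),
# or into an AvgPool with truncating rounding when the input and output quantisation are equal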
1229def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1230 if op.type == Op.Mean and op.run_on_npu:
1231 keep_dims = op.attrs.get("keep_dims", False)
1232 inp, axis = op.inputs
1233 shape = inp.shape
1234 dims = len(shape)
1235
1236 # Height and width axes have different index depending on dimensions
1237 if axis.shape == [] or axis.shape[0] == 1: # single axis
1238 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1239 if dims in (2, 3):
1240 if axis == 0:
1241 h, w = shape[axis], 1
1242 else:
1243 h, w = 1, shape[axis]
1244 else:
1245 if axis == 1:
1246 h, w = shape[axis], 1
1247 else:
1248 h, w = 1, shape[axis]
1249 else: # multiple axes
1250 axis = sorted(axis.values)
1251 h, w = [shape[i] for i in axis]
1252
1253 # Set necessary depthwise attributes
1254 op.attrs.update(
1255 {
1256 "padding": Padding.VALID,
1257 "stride_h": 1,
1258 "stride_w": 1,
1259 "strides": (1, 1, 1, 1),
1260 "depth_multiplier": 1,
1261 "channel_multiplier": 1,
1262 "dilation_h_factor": 1,
1263 "dilation_w_factor": 1,
1264 "dilation": (1, 1, 1, 1),
1265 }
1266 )
1267 # Change op type
1268 op.type = Op.DepthwiseConv2DBias
1269 # Set IFM/OFM shapes after changing op type
1270 op.set_ifm_ofm_shapes()
1271
1272 weight_scale, bias = 1, None
1273 ofmq, ifmq = op.ofm.quantization, inp.quantization
1274 # Set rounding mode, scaling and zero point based on which reference implementation to match
1275 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1276 if inp.dtype == DataType.uint8:
1277 # This attribute means a different scaling calculation is used in order to match reference
1278 op.low_precision_scaling = True
1279 weight_scale = h * w
1280 # Set zero points to 0 as they will be adjusted for with bias term
1281 foq = ofmq.clone()
1282 foq.zero_point = 0
1283 fiq = ifmq.clone()
1284 fiq.zero_point = 0
1285 op.forced_input_quantization = fiq
1286 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1287 # If the bias term is outside uint8 range, we need an Add op to apply it.
1288 if bias_term < 0 or bias_term > 255:
1289 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1290 # Bias term has higher bitness (i32) than input/output (u8).
1291 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1292 # the bias can only effectively assume values in the range [-255, 255].
1293 intermediate.dtype = DataType.int16
1294 intermediate.quantization.zero_point = 0
1295 add_op = Operation(Op.Add, op.name + "_bias")
1296 add_op.forced_output_quantization = foq
1297 add_op.add_input_tensor(intermediate)
1298 quant = QuantizationParameters()
1299 quant.zero_point = 0
1300 bias_term_tens = create_const_tensor(
                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
                    )
1303 add_op.add_input_tensor(bias_term_tens)
1304 add_op.set_output_tensor(op.ofm)
1305 add_op.set_ifm_ofm_shapes()
1306 add_op.activation = op.activation
1307 op.activation = None
1308 op.set_output_tensor(intermediate)
1309 op.set_ifm_ofm_shapes()
1310 # If not, we can just do it with the OFM zero point.
1311 else:
1312 foq.zero_point = bias_term
1313 op.forced_output_quantization = foq
1314 else:
1315 assert inp.dtype == DataType.int8
1316 # Use a depthwise to calculate the sum,
1317 # followed by a multiplication with 1/N to get the MEAN
1318 weight_scale = 1
1319 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1320 intermediate.dtype = DataType.int16
1321 mul_op = Operation(Op.Mul, op.name + "_mul")
1322 mul_op.add_input_tensor(intermediate)
1323 # Create scalar containing 1/N
1324 quant = QuantizationParameters()
1325 quant.zero_point = 0
1326 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1327 # while rounding mode NATURAL would round this to -1.
1328 # This can only occur if N is even, and can be emulated by
1329 # multiplying with a number that is slightly smaller than 1/N.
1330 # It must be so small that other roundings are not affected;
1331 # the calculated value is based on worst case,
1332 # which is sum 256 * N (the maximum sum that can occur with int8)
1333 n = int(h * w)
1334 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1335 quant.scale_f32 = 1 / (n - eps)
1336 scalar = create_const_tensor(
1337 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1338 )
1339 mul_op.add_input_tensor(scalar)
1340 mul_op.set_output_tensor(op.ofm)
1341 mul_op.set_ifm_ofm_shapes()
1342 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1343 mul_op.activation = op.activation
1344 op.activation = None
1345 op.set_output_tensor(intermediate)
1346 op.set_ifm_ofm_shapes()
1347 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1348 # Here we can just use a simple AvgPool with truncating rounding,
1349 # as we're emulating simple integer division.
1350 op.rounding_mode = NpuRoundingMode.TRUNCATE
1351 op.type = Op.AvgPool
1352 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1353 else:
1354 op.rounding_mode = NpuRoundingMode.NATURAL
1355 weight_scale = 1 / (h * w)
1356 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1357 bias = -ifmq.zero_point * h * w
1358 fiq = ifmq.clone()
1359 fiq.zero_point = 0
1360 op.forced_input_quantization = fiq
1361
1362 # Change dimensions to 4
1363 if dims < 4:
1364 shape = [1] + shape
1365 if dims == 2:
1366 shape += [1]
1367
1368 # If height is greater than max kernel height, reshape to from HxW to 1x(HxW)
1369 if h > 64:
1370 shape = [shape[0], 1, h * w, shape[3]]
1371 op.ifm_shapes[0] = Shape4D(shape)
1372 if h > 256 and op.type == Op.AvgPool:
1373 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1374
1375 # If the AvgPool version is used, we don't need to do anything else
1376 if op.type == Op.AvgPool:
1377 return op
1378
1379 # Make unit weight tensor quantization
1380 weight_quant = ifmq.clone()
1381 weight_quant.min = 0
1382 weight_quant.max = 255
1383 weight_quant.scale_f32 = weight_scale
1384 weight_quant.zero_point = 0
1385
1386 # Set weight shape to [H,W,C,B]
1387 weight_shape = shape[1:4] + [shape[0]]
1388 # Add unit weight tensor
1389 op.set_input_tensor(
1390 create_const_tensor(
1391 "weights",
1392 weight_shape,
1393 inp.dtype,
1394 np.ones(weight_shape),
1395 value_dtype=np.uint8,
1396 quantization=weight_quant,
1397 ),
1398 1,
1399 )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

1402 # Add None bias tensor
1403 op.inputs.append(None)
1404 # Add bias tensor
1405 if bias:
1406 bias_shape = [shape[-1]]
1407 op.set_input_tensor(
1408 create_const_tensor(
                    "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
                ),
1411 2,
1412 )
1413
1414 return op
1415
1416
1417def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op
1420
1421
1422def tflite_optimise_graph(nng, arch):
1423 # Pre-processing step
1424 pre_process_list = [
1425 supported_operator_check,
1426 set_ifm_ofm_op_shapes,
1427 ]
1428
1429 for idx, sg in enumerate(nng.subgraphs):
1430 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1431 nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
1432 )
1433
1434 # Handle Concat Ops
1435 for idx, sg in enumerate(nng.subgraphs):
1436 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1437 sg.refresh_after_modification()
1438
1439 # Handle Split Ops
1440 for idx, sg in enumerate(nng.subgraphs):
1441 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1442 nng,
1443 sg,
1444 arch,
1445 [],
1446 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1447 rewrite_unsupported=False,
1448 )
1449
1450 for idx, sg in enumerate(nng.subgraphs):
1451 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1452 nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
1453 )
1454
1455 # Handle sg input output
1456 for idx, sg in enumerate(nng.subgraphs):
1457 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1458 nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
1459 )
1460
    # Removal of reshapes and squeeze
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshape_and_squeeze_ops])
        sg.refresh_after_modification()
1465
1466 # Rewrite of operators
1467 op_rewrite_list = [
1468 set_tensor_equivalence,
1469 convert_mean_to_depthwise_conv_or_avgpool,
1470 convert_depthwise_to_conv,
1471 convert_conv_to_fc,
1472 convert_softmax,
1473 optimise_strided_conv,
1474 convert_hardswish_to_lut,
1475 rewrite_fully_connected_input,
1476 convert_batched_fc_shape,
1477 fixup_conv2d_backprop,
1478 fixup_relus_with_differing_ifm_ofm_scaling,
1479 fixup_elementwise_with_scalars,
1480 reorder_depthwise_weights,
1481 fixup_resizebilinear,
1482 fixup_bias_tensors,
1483 convert_mul_max_to_abs_or_lrelu,
1484 convert_lrelu,
1485 convert_tanh_sigmoid_to_lut,
1486 replace_pad_by_hw_pad,
1487 ]
1488
1489 for idx, sg in enumerate(nng.subgraphs):
1490 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1491 nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
1492 )
1493
1494 for idx, sg in enumerate(nng.subgraphs):
1495 # remove passthrough tensors and attempt further optimizations
1496 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1497 nng,
1498 sg,
1499 arch,
1500 [remove_passthrough_tensor],
1501 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1502 )
1503
1504 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
1505 # since ifm/ofm_shapes are of importance to this function
1506 for sg in nng.subgraphs:
1507 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1508 sg.refresh_after_modification()
1509
1510 return nng