# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import lut
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
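    """Rewrites a concat-type op (Concat/Pack) into a number of average pool ops that each write one input into the OFM at its offset."""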
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):
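    """Rewrites split-type ops by giving the output tensor a SplitSliceRead producer that reads the relevant slice of the split's input."""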

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        read_shape = offset_end

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):
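    """Removes a SplitSliceRead by moving the read into its consumer, or by inserting an average pool copy when that is not possible."""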

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in (Op.Reshape, Op.Squeeze)
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def insert_copy_op_after_tens(tens):
    tens_cons_list_copy = tens.consumer_list.copy()

    # Create an avg_pool nop op with ifm as input
    copy_tens = tens.clone()
    copy_op = create_avgpool_nop(tens.name + "_avgpool")
    copy_op.add_input_tensor(tens)
    copy_op.set_output_tensor(copy_tens)
    copy_op.set_ifm_ofm_shapes()
    copy_op.run_on_npu = True

    # Set copy_ifm consumers
    for tens_cons in tens_cons_list_copy:
        if tens_cons is not None:
            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
                if cons_inp == tens:
                    tens_cons.set_input_tensor(copy_tens, ifm_idx)

    DebugDatabase.add_optimised(tens.ops[0], copy_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
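    """Calculates the (top, left, bottom, right) padding and the skirt for a kernel, for SAME, VALID or EXPLICIT padding."""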
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unknown padding")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unknown padding")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
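    """Rewrites Conv2DBackpropInput to Conv2DBackpropInputSwitchedBias, using transposed IFM resampling and unit strides."""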
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm.resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resizebilinear_1x1_to_add(op):
    op.type = Op.Add
    op.name = op.name + "_add"
    op.attrs["resizebilinear"] = True
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear to a number of 2x2 pool ops
def convert_resizebilinear_to_2x2_pool(op):
    count = 0
    pre_op = op
    outputs = op.outputs

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    if op.attrs["align_corners"]:
        shape_modifier = 1
        op.attrs["padding"] = Padding.VALID
    else:
        shape_modifier = 0
        op.attrs["padding"] = Padding.SAME
    op.inputs[0].resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
    if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
        return op

    while (upscaled_shape < out_shape).all():
        if count == 0:
            scaled_op = pre_op
        else:
            scaled_op = op.clone("_{}".format(count))
            scaled_op.inputs[0] = pre_op.outputs[0]

        upscaled_shape = upscaled_shape * 2 - shape_modifier

        if (upscaled_shape == out_shape).all():
            scaled_op.outputs = outputs
            scaled_op.outputs[0].ops = [scaled_op]
        else:
            shape = op.ofm_shapes[0].as_list()
            shape[1:3] = upscaled_shape
            out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
            out_tens.quantization = op.outputs[0].quantization.clone()
            out_tens.quantization.quant_min = np.iinfo(np.int16).min
            out_tens.quantization.quant_max = np.iinfo(np.int16).max
            scaled_op.set_output_tensor(out_tens)
            pre_op = scaled_op
            count += 1

        # Setup the scale value
        if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
            scaled_op.rescale = 128
        elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
            scaled_op.rescale = 1 / 128
        else:
            scaled_op.rescale = None
        scaled_op.set_ifm_ofm_shapes()

    return op


def fixup_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass nop resizebilinear
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resizebilinear_1x1_to_add(op)
        else:
            convert_resizebilinear_to_2x2_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op, arch, nng):
    if op.type == Op.FullyConnected:
        n_in_elems = op.weights.shape[-2]
        elms = op.ifm.elements()
        batch_size = elms // n_in_elems
        assert batch_size * n_in_elems == elms

        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
    return op


def convert_batched_fc_shape(op, arch, nng):
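    """Moves the batch dimension of a batched FullyConnected into the height/width of a 4D IFM/OFM shape."""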
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
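    """Applies the new_axis_mask/shrink_axis_mask of a StridedSlice to the output shapes and records the 4D split axis."""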
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
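    """Rewrites Unpack (Unstack) to UnpackReshaped, re-inserting the unpacked axis as a dimension of size 1 in the output shapes."""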
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
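    """Rewrites an initial stride-2 Conv2D with a shallow, even-width IFM so that the IFM width is halved and its depth doubled, allowing stride 1."""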
    stride_x, stride_y = op.get_kernel_stride()
    ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()

    if (
        op.type == Op.Conv2DBias
        and op.op_index == 0
        and stride_x == 2
        and op.ifm_shapes[0].depth <= 4
        and op.ifm_shapes[0].width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        ifm_shape = op.ifm_shapes[0]
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def fixup_elementwise_with_scalars(op, arch, nng):
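    """Reshapes scalar or lower-rank inputs of binary elementwise ops so that both inputs have broadcast-compatible shapes."""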
    if op.type.is_binary_elementwise_op():
        ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
        if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
            diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
            if diff > 0:
                ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
            elif diff < 0:
                ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
        elif ifm_tensor.shape == [] and ifm_tensor.values is None:
            # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
            ifm_tensor.storage_shape = ifm_tensor.shape
        elif ifm2_tensor.shape == [] and ifm2_tensor.values is None:
            # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
            ifm2_tensor.storage_shape = ifm2_tensor.shape
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X   For X = -1 or X > 0
    |   \   /    This subgraph can be replaced with either
    |    Mul     an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scaling", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut(op, lut_values, lut_name):
    # Rewrite the operation by Add with scalar 0 + LUT activation
    ifm = op.inputs[0]
    if ifm is None:
        return op
    assert ifm.dtype.size_in_bytes() == 1
    op.type = Op.Add
    op.name = op.name + "_lut_" + lut_name
    # Mark as no-op to enable potential fusing optimizations
    op.attrs["is_nop"] = True
    # Create an input tensor containing scalar zero
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = ifm.quantization.scale_f32
    quantization.zero_point = 0
    tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
    op.add_input_tensor(tens)
    op.ifm_shapes.append(Shape4D(tens.shape))

    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
    # should be the same as the IFM
    op.forced_output_quantization = ifm.quantization
    lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
    op.set_activation_lut(lut_tensor)
    op.set_ifm_ofm_shapes()
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_reshape_and_squeeze_ops(op, arch):
    if op.run_on_npu and op.type in (Op.Reshape, Op.Squeeze):
        ofm = op.ofm
        ifm = op.ifm

        # Check if quantization is the same in the input and output for the reshape ops
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # TODO Both tensors are needed, since quantisation properties currently are linked to Tensors.
            # In order to remove this reshape either quantization properties need to be moved to Operator,
            # or the reshape needs to be replaced with a NOP.
            return

        bypass_reshape_and_squeeze_ops(op)


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                np.uint8,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
    return op


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = Shape4D(ifm.shape)
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op


def add_attrs_to_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        input_tensor = op.inputs[0]
        input_shape = op.ifm_shapes[0]
        upscaled_height = input_shape.height * 2
        upscaled_width = input_shape.width * 2
        out_shape = op.ofm_shapes[0]
        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
            # this means the output is supposed to be a x2 upscale,
            # so we need to do SAME padding
            op.attrs["padding"] = Padding.SAME
        elif (
            op.attrs["align_corners"]
            and out_shape.height == (upscaled_height - 1)
            and out_shape.width == (upscaled_width - 1)
        ):
            # here we can just run the avg pool without padding and
            # produce a (M * 2 - 1, N * 2 - 1) sized output
            op.attrs["padding"] = Padding.VALID
        else:
            return op
        input_tensor.resampling_mode = resampling_mode.NEAREST
        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    return op


def fixup_bias_tensors(op, arch, nng):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op


def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
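    """Converts a Mean over the height/width axes to a DepthwiseConv2DBias, or to an AvgPool when input and output quantisation allow it."""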
    if op.type == Op.Mean and op.run_on_npu:
        keep_dims = op.attrs.get("keep_dims", False)
        inp, axis = op.inputs
        shape = inp.shape
        dims = len(shape)

        # Height and width axes have different index depending on dimensions
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
            if dims in (2, 3):
                if axis == 0:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
            else:
                if axis == 1:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
        else:  # multiple axes
            axis = sorted(axis.values)
            h, w = [shape[i] for i in axis]

        # Set necessary depthwise attributes
        op.attrs.update(
            {
                "padding": Padding.VALID,
                "stride_h": 1,
                "stride_w": 1,
                "strides": (1, 1, 1, 1),
                "depth_multiplier": 1,
                "channel_multiplier": 1,
                "dilation_h_factor": 1,
                "dilation_w_factor": 1,
                "dilation": (1, 1, 1, 1),
            }
        )
        # Change op type
        op.type = Op.DepthwiseConv2DBias
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()

        weight_scale, bias = 1, None
        ofmq, ifmq = op.ofm.quantization, inp.quantization
        # Set rounding mode, scaling and zero point based on which reference implementation to match
        if len(shape) == 4 and axis == [1, 2] and keep_dims:
            if inp.dtype == DataType.uint8:
                # This attribute means a different scaling calculation is used in order to match reference
                op.low_precision_scaling = True
                weight_scale = h * w
                # Set zero points to 0 as they will be adjusted for with bias term
                foq = ofmq.clone()
                foq.zero_point = 0
                fiq = ifmq.clone()
                fiq.zero_point = 0
                op.forced_input_quantization = fiq
                bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
                # If the bias term is outside uint8 range, we need an Add op to apply it.
                if bias_term < 0 or bias_term > 255:
                    intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                    # Bias term has higher bitness (i32) than input/output (u8).
                    # 16 bits is enough since the bias is added/subtracted from a u8 value,
                    # the bias can only effectively assume values in the range [-255, 255].
                    intermediate.dtype = DataType.int16
                    intermediate.quantization.zero_point = 0
                    add_op = Operation(Op.Add, op.name + "_bias")
                    add_op.forced_output_quantization = foq
                    add_op.add_input_tensor(intermediate)
                    quant = QuantizationParameters()
                    quant.zero_point = 0
                    bias_term_tens = create_const_tensor(
                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
                    )
                    add_op.add_input_tensor(bias_term_tens)
                    add_op.set_output_tensor(op.ofm)
                    add_op.set_ifm_ofm_shapes()
                    add_op.activation = op.activation
                    op.activation = None
                    op.set_output_tensor(intermediate)
                    op.set_ifm_ofm_shapes()
                # If not, we can just do it with the OFM zero point.
                else:
                    foq.zero_point = bias_term
                    op.forced_output_quantization = foq
            else:
                assert inp.dtype == DataType.int8
                # Use a depthwise to calculate the sum,
                # followed by a multiplication with 1/N to get the MEAN
                weight_scale = 1
                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                intermediate.dtype = DataType.int16
                mul_op = Operation(Op.Mul, op.name + "_mul")
                mul_op.add_input_tensor(intermediate)
                # Create scalar containing 1/N
                quant = QuantizationParameters()
                quant.zero_point = 0
                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
                # while rounding mode NATURAL would round this to -1.
                # This can only occur if N is even, and can be emulated by
                # multiplying with a number that is slightly smaller than 1/N.
                # It must be so small that other roundings are not affected;
                # the calculated value is based on worst case,
                # which is sum 256 * N (the maximum sum that can occur with int8)
                n = int(h * w)
                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
                quant.scale_f32 = 1 / (n - eps)
                scalar = create_const_tensor(
                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
                )
                mul_op.add_input_tensor(scalar)
                mul_op.set_output_tensor(op.ofm)
                mul_op.set_ifm_ofm_shapes()
                mul_op.rounding_mode = NpuRoundingMode.NATURAL
                mul_op.activation = op.activation
                op.activation = None
                op.set_output_tensor(intermediate)
                op.set_ifm_ofm_shapes()
        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        if dims < 4:
            shape = [1] + shape
            if dims == 2:
                shape += [1]

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        if h > 64:
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        # Set weight shape to [H,W,C,B]
        weight_shape = shape[1:4] + [shape[0]]
        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                value_dtype=np.uint8,
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add None bias tensor
        op.inputs.append(None)
        # Add bias tensor
        if bias:
            bias_shape = [shape[-1]]
            op.set_input_tensor(
                create_const_tensor(
                    "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
                ),
                2,
            )

    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch):
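    """Runs all TFLite-specific graph optimisation passes over the subgraphs in the required order and returns the modified graph."""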
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
        )

    # Removal of reshape and squeeze ops
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshape_and_squeeze_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        optimise_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        fixup_elementwise_with_scalars,
        reorder_depthwise_weights,
        fixup_resizebilinear,
        fixup_bias_tensors,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    return nng