# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op
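
# Note on the concat lowering above: each input feature map becomes its own
# no-op average pool whose write_offset/write_shape place its data into the
# shared OFM. Illustrative example (shapes assumed for demonstration): two
# [1, 1, 1, 8] inputs concatenated along depth become two avgpools writing to
# write_offset (0, 0, 0, 0) and (0, 0, 0, 8) of a [1, 1, 1, 16] OFM.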


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op
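
# Illustrative axis mapping (shapes assumed): concatenating two [1, 4, 8]
# tensors along axis=2 gives 4D IFM shapes of [1, 1, 4, 8], and
# axis_4D = 2 + (4 - 3) = 3, i.e. the depth axis of the Shape4D form.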


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        read_shape = offset_end

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unknown padding {padding_type}")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt
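
# Worked example (values assumed): a 3x3 kernel with stride 1 on a 5x5 IFM and
# SAME padding needs xpad = ypad = 2 in total, split as (top, left, bottom,
# right) = (1, 1, 1, 1); the skirt then also extends (1, 1, 1, 1) around the IFM.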


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unknown padding {padding_type}")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm.resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resizebilinear_1x1_to_add(op):
    op.type = Op.Add
    op.name = op.name + "_add"
    op.attrs["resizebilinear"] = True
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear to a number of 2x2 pool ops
def convert_resizebilinear_to_2x2_pool(op):
    count = 0
    pre_op = op
    outputs = op.outputs

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    if op.attrs["align_corners"]:
        shape_modifier = 1
        op.attrs["padding"] = Padding.VALID
    else:
        shape_modifier = 0
        op.attrs["padding"] = Padding.SAME
    op.inputs[0].resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
    if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
        return op

    while (upscaled_shape < out_shape).all():
        if count == 0:
            scaled_op = pre_op
        else:
            scaled_op = op.clone("_{}".format(count))
            scaled_op.inputs[0] = pre_op.outputs[0]

        upscaled_shape = upscaled_shape * 2 - shape_modifier

        if (upscaled_shape == out_shape).all():
            scaled_op.outputs = outputs
            scaled_op.outputs[0].ops = [scaled_op]
        else:
            shape = op.ofm_shapes[0].as_list()
            shape[1:3] = upscaled_shape
            out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
            out_tens.quantization = op.outputs[0].quantization.clone()
            out_tens.quantization.quant_min = np.iinfo(np.int16).min
            out_tens.quantization.quant_max = np.iinfo(np.int16).max
            scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op
        count += 1

    # Set up the scale value
    if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
        scaled_op.rescale = 128
    elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
        scaled_op.rescale = 1 / 128
    else:
        scaled_op.rescale = None
    scaled_op.set_ifm_ofm_shapes()

    return op


def fixup_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass nop resizebilinear
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resizebilinear_1x1_to_add(op)
        else:
            convert_resizebilinear_to_2x2_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it doesn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op, arch, nng):
    if op.type == Op.FullyConnected:
        n_in_elems = op.weights.shape[-2]
        elms = op.ifm.elements()
        batch_size = elms // n_in_elems
        assert batch_size * n_in_elems == elms

        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
    return op
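
# Illustrative example (shapes assumed): an IFM with 24 elements and a weight
# tensor of shape [..., 12, n_out] gives batch_size = 24 // 12 = 2, so the
# FullyConnected IFM is rewritten to the 4D shape [2, 1, 1, 12].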


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op
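
# Illustrative example (values assumed): a batch of 8 is folded into the
# spatial dimensions as (h, w) = (2, 4), so [8, 1, 1, depth] is processed as
# [1, 2, 4, depth]; batch sizes without an entry in batching_split become (1, n).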


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op
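
# Note on the mask loops above: `mask &= mask - 1` clears the lowest set bit,
# so `prev_mask - mask` isolates that bit and its log2 is the axis index.
# Illustrative example (mask assumed): shrink_axis_mask = 0b0101 is processed
# as axis 0, then axis 2, inserting a 1 into the output shape at each axis.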


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    stride_x, stride_y = op.get_kernel_stride()
    ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()

    if (
        op.type == Op.Conv2DBias
        and op.op_index == 0
        and stride_x == 2
        and op.ifm_shapes[0].depth <= 4
        and op.ifm_shapes[0].width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        ifm_shape = op.ifm_shapes[0]
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op
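
# Illustrative example (shapes assumed): a stride-2 Conv2D on a [1, 8, 8, 3]
# IFM is rewritten to a stride-1 Conv2D on [1, 8, 4, 6]: width is halved and
# depth doubled, with the kernel reshaped to read pairs of columns at once.
# The depth <= 4 and op_index == 0 guards suggest this targets shallow-input
# first layers, where a strided read would otherwise be inefficient.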


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own primary op to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def fixup_elementwise_with_scalars(op, arch, nng):
    if op.type.is_binary_elementwise_op():
        ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
        if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
            diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
            if diff > 0:
                ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
            elif diff < 0:
                ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
        elif ifm_tensor.shape == [] and ifm_tensor.values is None:
            # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
            ifm_tensor.storage_shape = ifm_tensor.shape
        elif ifm2_tensor.shape == [] and ifm2_tensor.values is None:
            # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
            ifm2_tensor.storage_shape = ifm2_tensor.shape
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X         For X = -1 or X > 0
    |   \   /          This subgraph can be replaced with either
    |    Mul           an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scaling", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op
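
# Note on the fixed-point pipeline above (interpretive summary): hard-swish is
# x * relu6(x + 3) / 6, evaluated per LUT entry entirely in 16-bit fixed point.
# input_value_preshift carries x on (roughly) the output scale, relu_value
# approximates relu6(x + 3) / 6 as a Q15 value in [0, 1], and their saturating
# product plus the output zero point gives the quantized table entry.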


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)
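
# Worked LUT entry (quantization values assumed): for int8 with ifm_scale 0.1,
# zp_in 0, ofm_scale 1/256 and zp_out -128, the entry for x = 10 with fn = tanh
# is round_away_zero(-128 + tanh(1.0) * 256) = round_away_zero(66.97) = 67,
# clamped to the [-128, 127] range of the table.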


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
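
# Illustrative example (values assumed): for a 7-wide kernel (max_size 3) and
# stride 2, a leading pad of 3 or 2 is accepted (equal to max_size, or a
# multiple of the stride) but a leading pad of 1 is rejected; for a 3-wide
# kernel with stride 1 every legal leading pad passes, since max_size <= stride.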


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                np.uint8,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
    return op


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = Shape4D(ifm.shape)
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op
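
# Resulting OFM layout (illustrative): the top and bottom strips span the full
# OFM width, while the left and right strips only span the IFM height, so the
# five write regions (four zero borders plus the central IFM copy) tile the
# OFM exactly once.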


def add_attrs_to_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        input_tensor = op.inputs[0]
        input_shape = op.ifm_shapes[0]
        upscaled_height = input_shape.height * 2
        upscaled_width = input_shape.width * 2
        out_shape = op.ofm_shapes[0]
        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
            # this means the output is supposed to be a 2x upscale,
            # so we need to do SAME padding
            op.attrs["padding"] = Padding.SAME
        elif (
            op.attrs["align_corners"]
            and out_shape.height == (upscaled_height - 1)
            and out_shape.width == (upscaled_width - 1)
        ):
            # here we can just run the avg pool without padding and
            # produce a (M * 2 - 1, N * 2 - 1) sized output
            op.attrs["padding"] = Padding.VALID
        else:
            return op
        input_tensor.resampling_mode = resampling_mode.NEAREST
        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    return op


def fixup_bias_tensors(op, arch, nng):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op


def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
    if op.type == Op.Mean and op.run_on_npu:
        keep_dims = op.attrs.get("keep_dims", False)
        inp, axis = op.inputs
        shape = inp.shape
        dims = len(shape)

        # Height and width axes have different index depending on dimensions
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
            if dims in (2, 3):
                if axis == 0:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
            else:
                if axis == 1:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
        else:  # multiple axes
            axis = sorted(axis.values)
            h, w = [shape[i] for i in axis]

        # Set necessary depthwise attributes
        op.attrs.update(
            {
                "padding": Padding.VALID,
                "stride_h": 1,
                "stride_w": 1,
                "strides": (1, 1, 1, 1),
                "depth_multiplier": 1,
                "channel_multiplier": 1,
                "dilation_h_factor": 1,
                "dilation_w_factor": 1,
                "dilation": (1, 1, 1, 1),
            }
        )
        # Change op type
        op.type = Op.DepthwiseConv2DBias
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()

        weight_scale, bias = 1, None
        ofmq, ifmq = op.ofm.quantization, inp.quantization
        # Set rounding mode, scaling and zero point based on which reference implementation to match
        if len(shape) == 4 and axis == [1, 2] and keep_dims:
            if inp.dtype == DataType.uint8:
                # This attribute means a different scaling calculation is used in order to match reference
                op.low_precision_scaling = True
                weight_scale = h * w
                # Set zero points to 0 as they will be adjusted for with bias term
                foq = ofmq.clone()
                foq.zero_point = 0
                fiq = ifmq.clone()
                fiq.zero_point = 0
                op.forced_input_quantization = fiq
                bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
                # If the bias term is outside uint8 range, we need an Add op to apply it.
                if bias_term < 0 or bias_term > 255:
                    intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                    # Bias term has higher bitness (i32) than input/output (u8).
                    # 16 bits is enough since the bias is added/subtracted from a u8 value,
                    # the bias can only effectively assume values in the range [-255, 255].
                    intermediate.dtype = DataType.int16
                    intermediate.quantization.zero_point = 0
                    add_op = Operation(Op.Add, op.name + "_bias")
                    add_op.forced_output_quantization = foq
                    add_op.add_input_tensor(intermediate)
                    quant = QuantizationParameters()
                    quant.zero_point = 0
                    bias_term_tens = create_const_tensor(
                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
                    )
                    add_op.add_input_tensor(bias_term_tens)
                    add_op.set_output_tensor(op.ofm)
                    add_op.set_ifm_ofm_shapes()
                    add_op.activation = op.activation
                    op.activation = None
                    op.set_output_tensor(intermediate)
                    op.set_ifm_ofm_shapes()
                # If not, we can just do it with the OFM zero point.
                else:
                    foq.zero_point = bias_term
                    op.forced_output_quantization = foq
            else:
                assert inp.dtype == DataType.int8
                # Use a depthwise to calculate the sum,
                # followed by a multiplication with 1/N to get the MEAN
                weight_scale = 1
                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                intermediate.dtype = DataType.int16
                mul_op = Operation(Op.Mul, op.name + "_mul")
                mul_op.add_input_tensor(intermediate)
                # Create scalar containing 1/N
                quant = QuantizationParameters()
                quant.zero_point = 0
                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
                # while rounding mode NATURAL would round this to -1.
                # This can only occur if N is even, and can be emulated by
                # multiplying with a number that is slightly smaller than 1/N.
                # It must be so small that other roundings are not affected;
                # the calculated value is based on worst case,
                # which is sum 256 * N (the maximum sum that can occur with int8)
                n = int(h * w)
                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
                quant.scale_f32 = 1 / (n - eps)
                scalar = create_const_tensor(
                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
                )
                mul_op.add_input_tensor(scalar)
                mul_op.set_output_tensor(op.ofm)
                mul_op.set_ifm_ofm_shapes()
                mul_op.rounding_mode = NpuRoundingMode.NATURAL
                mul_op.activation = op.activation
                op.activation = None
                op.set_output_tensor(intermediate)
                op.set_ifm_ofm_shapes()
        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        if dims < 4:
            shape = [1] + shape
            if dims == 2:
                shape += [1]

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        if h > 64:
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        # Set weight shape to [H,W,C,B]
        weight_shape = shape[1:4] + [shape[0]]
        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                value_dtype=np.uint8,
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add None bias tensor
        op.inputs.append(None)
        # Add bias tensor
        if bias:
            bias_shape = [shape[-1]]
            op.set_input_tensor(
                create_const_tensor(
                    "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
                ),
                2,
            )

    return op
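
# Summary of the depthwise lowering (illustrative): MEAN over an HxW window is
# computed as a depthwise convolution with an all-ones [H, W, C, B] kernel,
# with the division by H*W handled through the weight/output scaling (or the
# following Mul in the int8 path above, depending on dtype and axes) and any
# zero-point adjustment emulated via the bias.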


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch):
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
        )

    # Removal of memory only operators
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        optimise_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        fixup_elementwise_with_scalars,
        reorder_depthwise_weights,
        fixup_resizebilinear,
        fixup_bias_tensors,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    return nng
1457 return nng