# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)

def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op

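# Bypasses a passthrough node (e.g. Identity) by returning its single input tensor in place of the given tensor.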
def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens

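# Rewrites a concatenation (including Pack/Stack) so that each input feature map is written
# into its slice of the OFM by a dedicated average-pool no-op (ConcatSliceWrite).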
def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op

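# Rewrites the producer of a split/slice output into a SplitSliceRead op that reads the
# correct offset and shape of the shared input tensor.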
def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        read_shape = offset_end

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens

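# Removes a SplitSliceRead by moving it to its consumer when possible; otherwise the read
# is kept by inserting an average-pool no-op that performs it.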
def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)

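# Calculates the explicit (top, left, bottom, right) padding and the corresponding skirt for
# an op, based on the padding type, kernel and input shape.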
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unknown padding")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt

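# Padding/skirt calculation for upscaled inputs, as used by transpose convolution
# (Conv2DBackpropInputSwitchedBias).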
def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unknown padding")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt

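# Rewrites Conv2DBackpropInput to Conv2DBackpropInputSwitchedBias: swaps the inputs, marks
# the IFM for transpose resampling and forces unit strides.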
def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm.resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resizebilinear_1x1_to_add(op):
    op.type = Op.Add
    op.name = op.name + "_add"
    op.attrs["resizebilinear"] = True
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear to a number of 2x2 pool ops
def convert_resizebilinear_to_2x2_pool(op):
    count = 0
    pre_op = op
    outputs = op.outputs

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    if op.attrs["align_corners"]:
        shape_modifier = 1
        op.attrs["padding"] = Padding.VALID
    else:
        shape_modifier = 0
        op.attrs["padding"] = Padding.SAME
    op.inputs[0].resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
    if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
        return op

    while (upscaled_shape < out_shape).all():
        if count == 0:
            scaled_op = pre_op
        else:
            scaled_op = op.clone("_{}".format(count))
            scaled_op.inputs[0] = pre_op.outputs[0]

        upscaled_shape = upscaled_shape * 2 - shape_modifier

        if (upscaled_shape == out_shape).all():
            scaled_op.outputs = outputs
            scaled_op.outputs[0].ops = [scaled_op]
        else:
            shape = op.ofm_shapes[0].as_list()
            shape[1:3] = upscaled_shape
            out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
            out_tens.quantization = op.outputs[0].quantization.clone()
            out_tens.quantization.quant_min = np.iinfo(np.int16).min
            out_tens.quantization.quant_max = np.iinfo(np.int16).max
            scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op
        count += 1

        # Setup the scale value
        if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
            scaled_op.rescale = 128
        elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
            scaled_op.rescale = 1 / 128
        else:
            scaled_op.rescale = None
        scaled_op.set_ifm_ofm_shapes()

    return op


def fixup_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass nop resizebilinear
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resizebilinear_1x1_to_add(op)
        else:
            convert_resizebilinear_to_2x2_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op

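# Flattens the IFM shape of a FullyConnected op to [batch, 1, 1, n_in_elems].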
def rewrite_fully_connected_input(op, arch, nng):
    if op.type == Op.FullyConnected:
        n_in_elems = op.weights.shape[-2]
        elms = op.ifm.elements()
        batch_size = elms // n_in_elems
        assert batch_size * n_in_elems == elms

        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
    return op

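# Spreads the batch dimension of a batched FullyConnected over height and width
# (e.g. batch 8 becomes a 2x4 plane) and reshapes the weights to 4D accordingly.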
def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()

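# Adjusts the OFM shapes of a StridedSlice that uses new_axis_mask or shrink_axis_mask and
# records the resulting 4D axis in the "split_axis_4D" attribute.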
def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op

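# Computes and attaches the "explicit_padding" and "skirt" attributes for NPU ops that carry
# a padding attribute.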
def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op

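# For the first Conv2D in the network with stride_x == 2, a shallow IFM and an even IFM
# width, halves the IFM width and doubles its depth so the convolution can run with stride 1.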
def optimise_strided_conv(op, arch, nng):
    stride_x, stride_y = op.get_kernel_stride()
    ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()

    if (
        op.type == Op.Conv2DBias
        and op.op_index == 0
        and stride_x == 2
        and op.ifm_shapes[0].depth <= 4
        and op.ifm_shapes[0].width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        ifm_shape = op.ifm_shapes[0]
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op

def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op

def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X         For X = -1 or X > 0
    |   \   /          This subgraph can be replaced with either
    |    Mul           an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op

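# Converts an int8/uint8 HardSwish into a LUT, using fixed-point arithmetic intended to
# match the TensorFlow Lite Micro reference kernel.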
def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)

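# Builds an int8/uint8 LUT for LeakyRelu, reusing the alpha scaling recorded by
# convert_mul_max_to_abs_or_lrelu when present.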
def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)

def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0

def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                np.uint8,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
    return op

def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op

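# For a ResizeBilinear that performs a 2x upscale, sets the pooling attributes and resampling
# mode needed to run it as a 2x2 average pool.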
def add_attrs_to_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        input_tensor = op.inputs[0]
        input_shape = op.ifm_shapes[0]
        upscaled_height = input_shape.height * 2
        upscaled_width = input_shape.width * 2
        out_shape = op.ofm_shapes[0]
        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
            # this means the output is supposed to be a x2 upscale,
            # so we need to do SAME padding
            op.attrs["padding"] = Padding.SAME
        elif (
            op.attrs["align_corners"]
            and out_shape.height == (upscaled_height - 1)
            and out_shape.width == (upscaled_width - 1)
        ):
            # here we can just run the avg pool without padding and
            # produce a (M * 2 - 1, N * 2 - 1) sized output
            op.attrs["padding"] = Padding.VALID
        else:
            return op
        input_tensor.resampling_mode = resampling_mode.NEAREST
        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    return op

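# Adds a zero-filled int32 bias tensor to ops that require a bias input but do not have one.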
def fixup_bias_tensors(op, arch, nng):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op

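# Converts Mean over height/width into a DepthwiseConv2DBias with unit weights (with extra
# scaling/bias handling to match the TFLite reference), or a plain AvgPool with truncating
# rounding when input and output quantisation are identical.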
def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
    if op.type == Op.Mean and op.run_on_npu:
        keep_dims = op.attrs.get("keep_dims", False)
        inp, axis = op.inputs
        shape = inp.shape
        dims = len(shape)

        # Height and width axes have different index depending on dimensions
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
            if dims in (2, 3):
                if axis == 0:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
            else:
                if axis == 1:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
        else:  # multiple axes
            axis = sorted(axis.values)
            h, w = [shape[i] for i in axis]

        # Set necessary depthwise attributes
        op.attrs.update(
            {
                "padding": Padding.VALID,
                "stride_h": 1,
                "stride_w": 1,
                "strides": (1, 1, 1, 1),
                "depth_multiplier": 1,
                "channel_multiplier": 1,
                "dilation_h_factor": 1,
                "dilation_w_factor": 1,
                "dilation": (1, 1, 1, 1),
            }
        )
        # Change op type
        op.type = Op.DepthwiseConv2DBias
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()

        weight_scale, bias = 1, None
        ofmq, ifmq = op.ofm.quantization, inp.quantization
        # Set rounding mode, scaling and zero point based on which reference implementation to match
        if len(shape) == 4 and axis == [1, 2] and keep_dims:
            if inp.dtype == DataType.uint8:
                # This attribute means a different scaling calculation is used in order to match reference
                op.low_precision_scaling = True
                weight_scale = h * w
                # Set zero points to 0 as they will be adjusted for with bias term
                foq = ofmq.clone()
                foq.zero_point = 0
                fiq = ifmq.clone()
                fiq.zero_point = 0
                op.forced_input_quantization = fiq
                bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
                # If the bias term is outside uint8 range, we need an Add op to apply it.
                if bias_term < 0 or bias_term > 255:
                    intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                    # Bias term has higher bitness (i32) than input/output (u8).
                    # 16 bits is enough since the bias is added/subtracted from a u8 value,
                    # the bias can only effectively assume values in the range [-255, 255].
                    intermediate.dtype = DataType.int16
                    intermediate.quantization.zero_point = 0
                    add_op = Operation(Op.Add, op.name + "_bias")
                    add_op.forced_output_quantization = foq
                    add_op.add_input_tensor(intermediate)
                    quant = QuantizationParameters()
                    quant.zero_point = 0
                    bias_term_tens = create_const_tensor(
                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
                    )
                    add_op.add_input_tensor(bias_term_tens)
                    add_op.set_output_tensor(op.ofm)
                    add_op.set_ifm_ofm_shapes()
                    add_op.activation = op.activation
                    op.activation = None
                    op.set_output_tensor(intermediate)
                    op.set_ifm_ofm_shapes()
                # If not, we can just do it with the OFM zero point.
                else:
                    foq.zero_point = bias_term
                    op.forced_output_quantization = foq
            else:
                assert inp.dtype == DataType.int8
                # Use a depthwise to calculate the sum,
                # followed by a multiplication with 1/N to get the MEAN
                weight_scale = 1
                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                intermediate.dtype = DataType.int16
                mul_op = Operation(Op.Mul, op.name + "_mul")
                mul_op.add_input_tensor(intermediate)
                # Create scalar containing 1/N
                quant = QuantizationParameters()
                quant.zero_point = 0
                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
                # while rounding mode NATURAL would round this to -1.
                # This can only occur if N is even, and can be emulated by
                # multiplying with a number that is slightly smaller than 1/N.
                # It must be so small that other roundings are not affected;
                # the calculated value is based on worst case,
                # which is sum 256 * N (the maximum sum that can occur with int8)
                n = int(h * w)
                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
                quant.scale_f32 = 1 / (n - eps)
                scalar = create_const_tensor(
                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
                )
                mul_op.add_input_tensor(scalar)
                mul_op.set_output_tensor(op.ofm)
                mul_op.set_ifm_ofm_shapes()
                mul_op.rounding_mode = NpuRoundingMode.NATURAL
                mul_op.activation = op.activation
                op.activation = None
                op.set_output_tensor(intermediate)
                op.set_ifm_ofm_shapes()
        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        if dims < 4:
            shape = [1] + shape
            if dims == 2:
                shape += [1]

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        if h > 64:
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        # Set weight shape to [H,W,C,B]
        weight_shape = shape[1:4] + [shape[0]]
        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                value_dtype=np.uint8,
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add None bias tensor
        op.inputs.append(None)
        # Add bias tensor
        if bias:
            bias_shape = [shape[-1]]
            op.set_input_tensor(
                create_const_tensor(
                    "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
                ),
                2,
            )

    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op

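# Main entry point for TFLite graph optimisation: applies the pre-processing, rewrite and
# clean-up passes above to every subgraph of the network.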
def tflite_optimise_graph(nng, arch):
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
        )

    # Handle subgraph input and output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
        )

    # Removal of memory only operators
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        optimise_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        fixup_resizebilinear,
        fixup_bias_tensors,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    return nng