# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op
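
# Illustrative note (added comment, not from the original source): rewrite_concat_ops below uses
# this helper to turn e.g. a depth-axis concat of two [1, 8, 8, 16] inputs into two AvgPool no-ops
# that write into the same [1, 8, 8, 32] OFM, one at write_offset (0, 0, 0, 0) and one at
# (0, 0, 0, 16); the ConcatSliceWrite memory function realises the concat purely through where
# each op writes its output.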


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unknown padding {padding_type}")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt
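
# Worked example (added comment): for a SAME-padded 3x3 kernel with stride 1 and dilation 1 on a
# 224x224 IFM, needed_total_padding gives xpad = ypad = 2, which splits into
# padding = (top, left, bottom, right) = (1, 1, 1, 1) and skirt = (1, 1, 1, 1).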


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unknown padding {padding_type}")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm.resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resizebilinear_1x1_to_add(op):
    op.type = Op.Add
    op.name = op.name + "_add"
    op.attrs["resizebilinear"] = True
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear to a number of 2x2 pool ops
def convert_resizebilinear_to_2x2_pool(op):
    count = 0
    pre_op = op
    outputs = op.outputs

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    if op.attrs["align_corners"]:
        shape_modifier = 1
        op.attrs["padding"] = Padding.VALID
    else:
        shape_modifier = 0
        op.attrs["padding"] = Padding.SAME
    op.inputs[0].resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
    out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
    if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
        return op

    while (upscaled_shape < out_shape).all():
        if count == 0:
            scaled_op = pre_op
        else:
            scaled_op = op.clone("_{}".format(count))
            scaled_op.inputs[0] = pre_op.outputs[0]

        upscaled_shape = upscaled_shape * 2 - shape_modifier

        if (upscaled_shape == out_shape).all():
            scaled_op.outputs = outputs
            scaled_op.outputs[0].ops = [scaled_op]
        else:
            shape = op.ofm_shapes[0].as_list()
            shape[1:3] = upscaled_shape
            out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
            out_tens.quantization = op.outputs[0].quantization.clone()
            out_tens.quantization.quant_min = np.iinfo(np.int16).min
            out_tens.quantization.quant_max = np.iinfo(np.int16).max
            scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op
        count += 1

        # Set up the scale value
        if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
            scaled_op.rescale = 128
        elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
            scaled_op.rescale = 1 / 128
        else:
            scaled_op.rescale = None
        scaled_op.set_ifm_ofm_shapes()

    return op
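
# Sizing note (added comment): each loop iteration doubles the upscaled shape (minus one with
# align_corners), so e.g. a 4x4 -> 16x16 ResizeBilinear without align_corners becomes two chained
# 2x2 average pools (4 -> 8 -> 16) with an int16 intermediate tensor between the stages.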


def fixup_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass nop resizebilinear
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resizebilinear_1x1_to_add(op)
        else:
            convert_resizebilinear_to_2x2_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op, arch, nng):
    if op.type == Op.FullyConnected:
        n_in_elems = op.weights.shape[-2]
        elms = op.ifm.elements()
        batch_size = elms // n_in_elems
        assert batch_size * n_in_elems == elms

        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
    return op
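
# Example (added comment): a FullyConnected with a [2, 1, 10, 30] IFM and a weight tensor whose
# second-to-last dimension is 300 has 600 IFM elements, so batch_size = 2 and the IFM shape is
# rewritten to Shape4D([2, 1, 1, 300]): one batch entry per row of 300 input elements.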


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op
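
# Example (added comment): with the batching_split table above, a FullyConnected whose
# ifm_shapes[0] is [8, 1, 1, 64] is remapped to [1, 2, 4, 64], spreading the batch over the
# height/width of a single pass; batch sizes not in the table fall back to (1, n).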


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op
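
# Bit-twiddling note (added comment): in the loops above, "mask &= mask - 1" clears the lowest
# set bit and int(math.log2(prev_mask - mask)) recovers its index, so a shrink_axis_mask of
# 0b0100 yields axis 2 and re-inserts a 1 at that position of the output shape, keeping the
# shapes 4D-compatible for the later split rewrite via the "split_axis_4D" attribute.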


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    stride_x, stride_y = op.get_kernel_stride()
    ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()

    if (
        op.type == Op.Conv2DBias
        and op.op_index == 0
        and stride_x == 2
        and op.ifm_shapes[0].depth <= 4
        and op.ifm_shapes[0].width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        ifm_shape = op.ifm_shapes[0]
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op
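
# Shape example (added comment, assuming a first-layer stride-2 Conv2D): a [1, 128, 128, 3] IFM
# with a [3, 3, 3, N] kernel is repacked above to a [1, 128, 64, 6] IFM, the kernel width is
# padded from 3 to 4 with zero-point values and the weights reshaped to [3, 2, 6, N], after
# which the convolution can run with stride_w = 1.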


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X          For X = -1 or X > 0
    |   \   /           This subgraph can be replaced with either
    |    Mul            an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scaling", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)
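
# Worked entry (added comment): for a Sigmoid with ifm_scale = 0.1, zp_in = 0, ofm_scale = 1 / 256
# and zp_out = -128, the table entry for x = 10 is round_away_zero(-128 + sigmoid(1.0) / (1 / 256))
# = round_away_zero(-128 + 187.15) = 59, then clamped to the int8 range.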


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
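
# Example (added comment): for a 7x7 kernel (max_size = 3) and stride 2, a leading pad of 2 is
# accepted (2 % 2 == 0) while a leading pad of 1 is rejected, because hardware padding of 1 would
# start the first kernel window on different IFM rows/columns than the bypassed PAD produced.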


def replace_pad_by_hw_pad(op: Operation, arch, nng):
    """
    Tries to completely remove a PAD operator by using hardware padding.
    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
    if both operations can be run on the NPU.
    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
    """
    if (
        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
        and op.run_on_npu
        and op.attrs["padding"] == Padding.VALID
    ):
        pad_op = op.ifm.ops[0]
        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
            return op
        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
            return op
        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
        k = op.kernel
        k_w, k_h = k.dilated_wh()

        # Check if the PAD operator can be replaced by hardware padding
        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
            # Too much padding, it would require hardware padding to actually insert zeros
            return op
        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
            return op

        if op.type.is_avgpool_op():
            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
            for pad, k_size in (
                (left, k_w),
                (right, k_w),
                (top, k_h),
                (bottom, k_h),
            ):
                if pad not in (0, k_size // 2):
                    return op
            # Average pool is converted to depthwise, because NPU average pool + same padding
            # has a special implementation that is different from PAD followed by average pool with
            # valid padding.
            k_w, k_h = op.kernel.width, op.kernel.height
            ifm = op.ifm
            # Remember other inputs
            other_inputs = op.inputs[1:]
            # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
            quantization = QuantizationParameters(0.0, 255.0)
            quantization.scale_f32 = 1.0 / (k_w * k_h)
            quantization.zero_point = 0
            shape = [k_h, k_w, 1, op.ofm.shape[-1]]
            weights = np.full(shape, 1)

            weight_tens = create_const_tensor(
                op.name + "_weights",
                shape,
                op.ifm.dtype,
                weights,
                np.uint8,
                purpose=TensorPurpose.Weights,
                quantization=quantization,
            )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
            op.inputs = []
            op.add_input_tensor(ifm)
            op.add_input_tensor(weight_tens)
            # Add bias tensor, all biases set to 0
            op.inputs.append(None)
            fixup_bias_tensors(op, arch, nng)
            # Add other inputs
            op.inputs.extend(other_inputs)
            op.rounding_mode = NpuRoundingMode.NATURAL

        # Bypass the PAD operator
        op.set_input_tensor(pad_op.ifm, 0)
        # Adjust the padding attributes of the convolution operator
        op.attrs["padding"] = Padding.EXPLICIT
        op.attrs["explicit_padding"] = (top, left, bottom, right)
        op.set_ifm_ofm_shapes()
    return op
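
# Example rewrite (added comment): tens1 -> PAD(1, 1, 1, 1) -> tens2 -> 3x3 CONV with VALID
# padding becomes tens1 -> 3x3 CONV with Padding.EXPLICIT and explicit_padding = (1, 1, 1, 1);
# the standalone PAD disappears since each pad is <= kernel_size // 2 and the leading-pad check
# passes for stride 1.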


def convert_pad(op: Operation, arch, nng):
    """
    Rewrites PAD operator to an average pool that copies the IFM to the OFM
    + up to 4 average pool operators that fill the OFM with zeros at the borders.
    This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
    """
    if op.type != Op.Pad or not op.run_on_npu:
        return op
    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)

    ifm = op.ifm
    assert ifm is not None
    ifm_shape = op.ifm_shapes[0]
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    # Average pool op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    avgpool_op.activation = op.activation
    quant = ofm.quantization
    pad_value = quant.zero_point
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
        )
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_avg_pool_for_concat(
            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
        )

    op.type = Op.ConcatTFLite
    return avgpool_op
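
# Example (added comment): a PAD of (top, left, bottom, right) = (1, 2, 0, 0) on a [1, 8, 8, 4]
# IFM produces a [1, 9, 10, 4] OFM built from one AvgPool copying the IFM at write offset
# (0, 1, 2, 0), a "_top" zero-fill write of shape [1, 1, 10, 4] and a "_left" write of shape
# [1, 8, 2, 4]; no "_bottom" or "_right" ops are created because those pads are zero.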


def add_attrs_to_resizebilinear(op, arch, nng):
    if op.type == Op.ResizeBilinear and op.run_on_npu:
        input_tensor = op.inputs[0]
        input_shape = op.ifm_shapes[0]
        upscaled_height = input_shape.height * 2
        upscaled_width = input_shape.width * 2
        out_shape = op.ofm_shapes[0]
        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
            # this means the output is supposed to be a x2 upscale,
            # so we need to do SAME padding
            op.attrs["padding"] = Padding.SAME
        elif (
            op.attrs["align_corners"]
            and out_shape.height == (upscaled_height - 1)
            and out_shape.width == (upscaled_width - 1)
        ):
            # here we can just run the avg pool without padding and
            # produce a (M * 2 - 1, N * 2 - 1) sized output
            op.attrs["padding"] = Padding.VALID
        else:
            return op
        input_tensor.resampling_mode = resampling_mode.NEAREST
        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    return op


def fixup_bias_tensors(op, arch, nng):
    if op.type.needs_bias() and op.bias is None:
        # Op has no bias, add bias tensor filled with zeros
        nr_biases = op.inputs[1].shape[-1]
        bias_values = [0] * nr_biases
        bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])

    return op


def fixup_asymmetric_weights(op, arch, nng):
    if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
        if op.ifm.dtype == DataType.int8:
            if not np.all(op.weights.quantization.zero_point == 0):
                print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
                op.weights.quantization.zero_point *= 0

    return op


def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
    if op.type == Op.Mean and op.run_on_npu:
        keep_dims = op.attrs.get("keep_dims", False)
        inp, axis = op.inputs
        shape = inp.shape
        dims = len(shape)

        # Height and width axes have different index depending on dimensions
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
            if dims in (2, 3):
                if axis == 0:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
            else:
                if axis == 1:
                    h, w = shape[axis], 1
                else:
                    h, w = 1, shape[axis]
        else:  # multiple axes
            axis = sorted(axis.values)
            h, w = [shape[i] for i in axis]

        # Set necessary depthwise attributes
        op.attrs.update(
            {
                "padding": Padding.VALID,
                "stride_h": 1,
                "stride_w": 1,
                "strides": (1, 1, 1, 1),
                "depth_multiplier": 1,
                "channel_multiplier": 1,
                "dilation_h_factor": 1,
                "dilation_w_factor": 1,
                "dilation": (1, 1, 1, 1),
            }
        )
        # Change op type
        op.type = Op.DepthwiseConv2DBias
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()

        weight_scale, bias = 1, None
        ofmq, ifmq = op.ofm.quantization, inp.quantization
        # Set rounding mode, scaling and zero point based on which reference implementation to match
        if len(shape) == 4 and axis == [1, 2] and keep_dims:
            if inp.dtype == DataType.uint8:
                # This attribute means a different scaling calculation is used in order to match reference
                op.low_precision_scaling = True
                weight_scale = h * w
                # Set zero points to 0 as they will be adjusted for with bias term
                foq = ofmq.clone()
                foq.zero_point = 0
                fiq = ifmq.clone()
                fiq.zero_point = 0
                op.forced_input_quantization = fiq
                bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
                # If the bias term is outside uint8 range, we need an Add op to apply it.
                if bias_term < 0 or bias_term > 255:
                    intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                    # Bias term has higher bitness (i32) than input/output (u8).
                    # 16 bits is enough since the bias is added/subtracted from a u8 value,
                    # the bias can only effectively assume values in the range [-255, 255].
                    intermediate.dtype = DataType.int16
                    intermediate.quantization.zero_point = 0
                    add_op = Operation(Op.Add, op.name + "_bias")
                    add_op.forced_output_quantization = foq
                    add_op.add_input_tensor(intermediate)
                    quant = QuantizationParameters()
                    quant.zero_point = 0
                    bias_term_tens = create_const_tensor(
                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
                    )
                    add_op.add_input_tensor(bias_term_tens)
                    add_op.set_output_tensor(op.ofm)
                    add_op.set_ifm_ofm_shapes()
                    add_op.activation = op.activation
                    op.activation = None
                    op.set_output_tensor(intermediate)
                    op.set_ifm_ofm_shapes()
                # If not, we can just do it with the OFM zero point.
                else:
                    foq.zero_point = bias_term
                    op.forced_output_quantization = foq
            else:
                assert inp.dtype == DataType.int8
                # Use a depthwise to calculate the sum,
                # followed by a multiplication with 1/N to get the MEAN
                weight_scale = 1
                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                intermediate.dtype = DataType.int16
                mul_op = Operation(Op.Mul, op.name + "_mul")
                mul_op.add_input_tensor(intermediate)
                # Create scalar containing 1/N
                quant = QuantizationParameters()
                quant.zero_point = 0
                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
                # while rounding mode NATURAL would round this to -1.
                # This can only occur if N is even, and can be emulated by
                # multiplying with a number that is slightly smaller than 1/N.
                # It must be so small that other roundings are not affected;
                # the calculated value is based on worst case,
                # which is sum 256 * N (the maximum sum that can occur with int8)
                n = int(h * w)
                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
                quant.scale_f32 = 1 / (n - eps)
                scalar = create_const_tensor(
                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
                )
                mul_op.add_input_tensor(scalar)
                mul_op.set_output_tensor(op.ofm)
                mul_op.set_ifm_ofm_shapes()
                mul_op.rounding_mode = NpuRoundingMode.NATURAL
                mul_op.activation = op.activation
                op.activation = None
                op.set_output_tensor(intermediate)
                op.set_ifm_ofm_shapes()
        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        if dims < 4:
            shape = [1] + shape
            if dims == 2:
                shape += [1]

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        # Set weight shape to [H,W,C,B]
        weight_shape = shape[1:4] + [shape[0]]
        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                value_dtype=np.uint8,
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add None bias tensor
        op.inputs.append(None)
        # Add bias tensor
        if bias:
            bias_shape = [shape[-1]]
            op.set_input_tensor(
                create_const_tensor(
                    "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
                ),
                2,
            )

    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch):
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
        )

    # Removal of memory only operators
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        optimise_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        fixup_resizebilinear,
        fixup_bias_tensors,
        fixup_asymmetric_weights,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    return nng