# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import lut
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)
64
65
66def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
67 """Creates an average pool for the given concat op/input feature map"""
68 ofm = concat_op.ofm
69 avgpool_op = create_avgpool_nop(name)
70 avgpool_op.inputs = [ifm]
71 avgpool_op.outputs = [ofm]
72
73 avgpool_op.write_offset = write_offset
74 avgpool_op.write_shape = ifm_shape
75 ofm.ops.append(avgpool_op)
76 DebugDatabase.add_optimised(concat_op, avgpool_op)
77 avgpool_op.ifm_shapes.append(ifm_shape)
78 avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
79 avgpool_op.memory_function = Op.ConcatSliceWrite
80 return avgpool_op
81
82
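# Bypass tensors produced by passthrough ops (e.g. Identity) by returning the op's input tensor instead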
83def remove_passthrough_tensor(tens, arch, nng):
84 if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
85 assert len(tens.ops[0].inputs) == 1
86 tens = tens.ops[0].inputs[0]
87 return tens
88
89
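# Rewrite a concat op (including Pack) into one average pool per input, each writing its IFM
# into the OFM at the correct offset along the concatenation axis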
90def rewrite_concat_ops(op, arch):
91 if not op.run_on_npu or not op.type.is_concat_op():
92 return
93
94 axis_4D = 0
95 ofm = op.ofm
96 ofm.ops = []
97 offset = 0
98
99 unfuse_activation_function(op)
100
101 if op.type == Op.Pack:
102 # Pack is also referred to as Stack
103 axis = int(op.attrs["axis"])
104 if axis < 0: # Convert to positive axis
105 axis = len(op.inputs[0].shape) + 1 + axis
106
107 desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
108
109 axis_4D = axis + (4 - len(desired_shape))
110
111 for idx, inp in enumerate(op.inputs):
112 op.ifm_shapes[idx] = Shape4D(desired_shape)
113 op.type = Op.PackReshaped
114
115 inputs, axis = op.get_concat_inputs_axis()
116 for idx, inp in enumerate(inputs):
117 if op.type != Op.PackReshaped:
118 op.ifm_shapes[idx] = Shape4D(inp.shape)
119 if axis >= 0:
120 axis_4D = axis + (4 - len(inp.shape))
121 else:
122 axis_4D = axis
123 write_offset = [0, 0, 0, 0]
124 write_offset[axis_4D] = offset
125 concat_end = offset + op.ifm_shapes[idx][axis_4D]
126 create_avg_pool_for_concat(
127 op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
128 )
129 offset = concat_end
130 assert ofm.shape[axis] == offset
131
132 return op
133
134
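# Rewrite the producer of a split/slice output tensor into a SplitSliceRead op that reads the
# correct offset and shape from the split op's input tensor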
135def rewrite_split_ops(tens, arch, nng):
136
137 if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
138 split_op = tens.ops[0]
139
140 # Not supported so leave it and run on CPU
141 if not split_op.run_on_npu:
142 return tens
143
144 inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()
145
146 tens.ops = []
147 new_op = Operation(Op.SplitSliceRead, split_op.name)
148 new_op.inputs = [inp]
149 ofm_shape_idx = 0
150 read_shape = offset_end
151
152 # For Split the offset cannot be extracted from the tensor so it has to
153 # be calculated from the index of the output tensor
154 if axis is not None:
155 # Get the start and end of the split
156 offset_start = [0] * 4
157 axis_4D_list = split_op.attrs.get("split_axis_4D", None) # Present for UnpackReshaped and some StridedSlice
158 for idx, out in enumerate(outputs):
159 if axis_4D_list is not None:
160 axis_4D = axis_4D_list[idx]
161 else:
162 split_op.ofm_shapes[idx] = Shape4D(out.shape)
163 if axis >= 0:
164 axis_4D = axis + (4 - len(out.shape))
165 else:
166 axis_4D = axis
167
168 if out == tens:
169 ofm_shape_idx = idx
170 read_shape = split_op.ofm_shapes[idx]
171 break
172
173 offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
174
175 new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
176 new_op.read_shapes[0] = read_shape
177 new_op.run_on_npu = True
178 new_op.set_output_tensor(tens)
179 new_op.ifm_shapes.append(Shape4D(inp.shape))
180 new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
181 DebugDatabase.add_optimised(split_op, new_op)
182
183 return tens
184
185
186def remove_SplitSliceRead(op, arch):
187
188 if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by the tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
201 avgpool_op = create_avgpool_nop(op.name + "_avgpool")
202 avgpool_op.add_input_tensor(op.ifm)
203 avgpool_op.outputs = [op.ofm]
204 op.ofm.ops.remove(op)
205 op.ofm.ops.append(avgpool_op)
206 avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
207 avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
208 avgpool_op.read_offsets[0] = op.read_offsets[0]
209 avgpool_op.read_shapes[0] = op.read_shapes[0]
210
211 op.ifm.consumer_list.remove(op)
212 DebugDatabase.add_optimised(op, avgpool_op)
213
214
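# Calculate the explicit (top, left, bottom, right) padding and skirt for SAME, VALID or EXPLICIT
# padding; e.g. SAME padding with a 3x3 kernel and stride 1 gives a padding of (1, 1, 1, 1)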
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
216 k_w, k_h = kernel.dilated_wh()
217 s_x, s_y = kernel.stride
218 ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
219 xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
220 if padding_type == Padding.SAME:
221 left_pad = (xpad + 0) // 2
222 right_pad = (xpad + 1) // 2
223 top_pad = (ypad + 0) // 2
224 bottom_pad = (ypad + 1) // 2
225 elif padding_type == Padding.VALID:
226 left_pad = 0
227 right_pad = 0
228 top_pad = 0
229 bottom_pad = 0
230 elif padding_type == Padding.EXPLICIT:
231 # Padding is specified in a PAD operator which has been bypassed.
232 top, left, bottom, right = explicit_padding
233 top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
234 left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
235 else:
236 raise UnsupportedFeatureError(f"Unknown padding")
237 padding = (top_pad, left_pad, bottom_pad, right_pad)
238 skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
239 return padding, skirt
240
241
242def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
243 kernel_height, kernel_width = kernel_size[0], kernel_size[1]
244 if padding_type == Padding.SAME:
245 ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
246 xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
247 right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
248 bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
249 left_pad = max(kernel_width - 1 - right_pad, 0)
250 top_pad = max(kernel_height - 1 - bottom_pad, 0)
251 elif padding_type == Padding.VALID:
252 right_pad = max(kernel_width - 2, 0)
253 bottom_pad = max(kernel_height - 2, 0)
254 left_pad = kernel_width - 1
255 top_pad = kernel_height - 1
256 else:
257 raise UnsupportedFeatureError(f"Unknown padding")
258 padding = (top_pad, left_pad, bottom_pad, right_pad)
259 skirt = padding
260 return padding, skirt
261
262
263def fixup_conv2d_backprop(op, arch, nng):
264 if op.type == Op.Conv2DBackpropInput:
265 # flip the inputs
266 op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
267 op.type = Op.Conv2DBackpropInputSwitchedBias
268 op.ifm.resampling_mode = resampling_mode.TRANSPOSE
269
270 # Update strides
271 op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})
272
273 return op
274
275
276# Convert the op to an elementwise add
277def convert_resizebilinear_1x1_to_add(op):
278 op.type = Op.Add
279 op.name = op.name + "_add"
280 op.attrs["resizebilinear"] = True
281 # Create an input tensor filled with zeros
282 shape = op.ofm_shapes[0].as_list()
283 tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
286 tens.quantization.scale_f32 = 1.0
287 tens.quantization.zero_point = 0
288 tens.consumer_list = [op]
289 tens_op = op.inputs[1].ops[0]
290 tens_op.set_output_tensor(tens)
291 # Set the add inputs
292 op.inputs[1] = op.inputs[0]
293 op.inputs[0] = tens
294 op.set_ifm_ofm_shapes()
295
296 return op
297
298
299# Convert ResizeBilinear to a number of 2x2 pool ops
300def convert_resizebilinear_to_2x2_pool(op):
301 count = 0
302 pre_op = op
303 outputs = op.outputs
304
305 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
306 if op.attrs["align_corners"]:
307 shape_modifier = 1
308 op.attrs["padding"] = Padding.VALID
309 else:
310 shape_modifier = 0
311 op.attrs["padding"] = Padding.SAME
312 op.inputs[0].resampling_mode = resampling_mode.NEAREST
313
314 upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())
315 out_shape = np.array(op.ofm_shapes[0].get_hw_as_list())
316 if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
317 return op
318
319 while (upscaled_shape < out_shape).all():
320 if count == 0:
321 scaled_op = pre_op
322 else:
323 scaled_op = op.clone("_{}".format(count))
324 scaled_op.inputs[0] = pre_op.outputs[0]
325
326 upscaled_shape = upscaled_shape * 2 - shape_modifier
327
328 if (upscaled_shape == out_shape).all():
329 scaled_op.outputs = outputs
330 scaled_op.outputs[0].ops = [scaled_op]
331 else:
332 shape = op.ofm_shapes[0].as_list()
333 shape[1:3] = upscaled_shape
334 out_tens = Tensor(shape, DataType.int16, "{}_{}".format(op.outputs[0].name, count))
335 out_tens.quantization = op.outputs[0].quantization.clone()
336 out_tens.quantization.quant_min = np.iinfo(np.int16).min
337 out_tens.quantization.quant_max = np.iinfo(np.int16).max
338 scaled_op.set_output_tensor(out_tens)
339 pre_op = scaled_op
340 count += 1
341
342 # Setup the scale value
343 if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
344 scaled_op.rescale = 128
345 elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
346 scaled_op.rescale = 1 / 128
347 else:
348 scaled_op.rescale = None
349 scaled_op.set_ifm_ofm_shapes()
350
351 return op
352
353
354def fixup_resizebilinear(op, arch, nng):
355 if op.type == Op.ResizeBilinear and op.run_on_npu:
356 if op.ifm_shapes[0] == op.ofm_shapes[0]:
357 # Bypass nop resizebilinear
358 op.inputs = op.inputs[:1]
359 op.type = Op.Identity
360 elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
361 convert_resizebilinear_1x1_to_add(op)
362 else:
363 convert_resizebilinear_to_2x2_pool(op)
364
365 return op
366
367
368def convert_nop_split_to_identity(op, arch, nng):
369 if op.type == Op.Split and op.attrs.get("num_splits") == 1:
370 # the list comprehension should return a list with a single tensor
371 # if it shouldn't, remove_passthrough_tensor will fail appropriately
372 op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
373 op.type = Op.Identity
374 return op
375
376
377def rewrite_fully_connected_input(op, arch, nng):
378 if op.type == Op.FullyConnected:
379 n_in_elems = op.weights.shape[-2]
380 elms = op.ifm.elements()
381 batch_size = elms // n_in_elems
382 assert batch_size * n_in_elems == elms
383
384 op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
385 return op
386
387
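# Convert a batched FullyConnected (batch > 1) to a 4D shape, splitting the batch over height
# and width (e.g. a batch of 8 becomes a 2x4 HxW plane, see the batching_split table below)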
388def convert_batched_fc_shape(op, arch, nng):
389 if op.type == Op.FullyConnected:
390 # Check if the first dimension indicates batching
391 if op.ifm_shapes[0].batch > 1:
392 batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
393 n = op.ifm_shapes[0].batch
394 h, w = batching_split.get(n, (1, n))
395 op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
396
397 # Reshape Weights to be 4D. IO becomes HWIO
398 weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

402 n = op.ofm_shapes[0].batch
403 h, w = batching_split.get(n, (1, n))
404 op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
405 return op
406
407
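# Split a fused activation function out of a concat op into a separate op, connected through an
# intermediate tensor (the concat itself is later rewritten into average pools)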
408def unfuse_activation_function(op):
409 if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
410 act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
411 op.activation = None
412 out_tens = op.outputs[0]
413 intermediate_tens = out_tens.clone("_act_intermediate")
414 act_op.set_output_tensor(out_tens)
415 act_op.add_input_tensor(intermediate_tens)
416 op.set_output_tensor(intermediate_tens)
417 act_op.set_ifm_ofm_shapes()
418
419
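# Adjust the output shapes of a StridedSlice op that uses new_axis_mask or shrink_axis_mask,
# and store the resulting 4D axes in the "split_axis_4D" attribute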
420def rewrite_stridedslice_output(op, arch, nng):
421 if not op.run_on_npu or op.type != Op.StridedSlice:
422 return op
423
424 new_axis_mask = op.attrs["new_axis_mask"]
425 shrink_axis_mask = op.attrs["shrink_axis_mask"]
426
427 if shrink_axis_mask == 0 and new_axis_mask == 0:
428 return op
429
430 axis_4D = [0] * len(op.outputs)
431 for idx, out_tens in enumerate(op.outputs):
432 output_shape = list(out_tens.shape)
433
434 if shrink_axis_mask != 0:
435 n = 0
436 axis = 0
437 while shrink_axis_mask:
438 prev_mask = shrink_axis_mask
439 n += 1
440 shrink_axis_mask &= shrink_axis_mask - 1
441 axis = int(math.log2(prev_mask - shrink_axis_mask))
442 output_shape = output_shape[:axis] + [1] + output_shape[axis:]
443
444 assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
445 op.attrs["shrink_axis_mask"] = 0
446 if axis >= 0:
447 axis_4D[idx] = axis + (4 - len(output_shape))
448 else:
449 axis_4D[idx] = axis
450 op.ofm_shapes[idx] = Shape4D(output_shape)
451
452 elif new_axis_mask != 0:
453 n = 0
454 axis = 0
455 while new_axis_mask:
456 prev_mask = new_axis_mask
457 n += 1
458 new_axis_mask &= new_axis_mask - 1
459 axis = int(math.log2(prev_mask - new_axis_mask))
460 output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
461 new_axis_mask >>= 1
462
463 assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
464 op.attrs["new_axis_mask"] = 0
465 if axis >= 0:
466 axis_4D[idx] = axis + (4 - len(output_shape))
467 else:
468 axis_4D[idx] = axis
469 op.ofm_shapes[idx] = Shape4D(output_shape)
470
471 op.attrs["split_axis_4D"] = axis_4D
472 return op
473
474
475def rewrite_unpack_output(op, arch, nng):
476 tens = op.outputs[0]
477 if op.run_on_npu and op.type == Op.Unpack:
478 # Unpack is also referred to as Unstack
479 axis = int(op.attrs["axis"])
480 if axis < 0: # Convert to positive axis
481 axis = len(op.inputs[0].shape) + 1 + axis
482 op.type = Op.UnpackReshaped
483 desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
484
485 axis_4D = axis + (4 - len(desired_output_shape))
486 op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
487
488 for idx, out_tens in enumerate(op.outputs):
489 op.ofm_shapes[idx] = Shape4D(desired_output_shape)
490 return op
491
492
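# Calculate the "explicit_padding" and "skirt" attributes for ops that use padding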
493def add_padding_fields(op, arch, nng):
494 if op.run_on_npu:
495 if "padding" in op.attrs:
496 input_shape = op.ifm_shapes[0]
497 output_shape = op.ofm_shapes[0]
498 if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
499 kernel_size = op.inputs[1].shape[:2]
500 elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
501 kernel_size = op.attrs["ksize"][1:3]
502 else:
503 raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")
504
505 if op.type == Op.Conv2DBackpropInputSwitchedBias:
506 upscaling_factor = output_shape.height // input_shape.height
507 padding, skirt = calc_upscaled_padding_and_skirt(
508 op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
509 )
510 else:
511 padding, skirt = calc_padding_and_skirt(
512 op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
513 )
514
515 op.attrs["explicit_padding"] = padding
516 op.attrs["skirt"] = skirt
517
518 return op
519
520
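# Transpose the two innermost axes of depthwise conv weights and flag the tensor as transposed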
def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True
527
528 return op
529
530
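# Rewrite a Conv2D with stride 2 in the width direction as a stride 1 conv, by packing pairs of
# IFM columns into the depth dimension and padding/reshaping the weights to match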
531def optimise_strided_conv(op, arch, nng):
532 stride_x, stride_y = op.get_kernel_stride()
533 ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
534
535 if (
536 op.type == Op.Conv2DBias
537 and op.op_index == 0
538 and stride_x == 2
539 and op.ifm_shapes[0].depth <= 4
540 and op.ifm_shapes[0].width % 2 == 0
541 and weight_tensor is not None
542 and weight_tensor.shape[1] >= 2
543 ):
544 ifm_shape = op.ifm_shapes[0]
545 # IFM
546 op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
547
548 # Weights
549 weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
565 # If multiple copies of the weights are used, we could avoid
566 # them having the same address by changing the value_id
567 weight_tensor.value_id = uuid.uuid4()
568
569 # Strides
570 stride_x = 1
571 op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})
572
573 return op
574
575
576def convert_conv_to_fc(op, arch, nng):
577 # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
581 if op.type == Op.Conv2DBias:
582 h = op.ifm_shapes[0].height
583 w = op.ifm_shapes[0].width
584 kh, kw, _, _ = op.inputs[1].shape
585 if h == 1 and w == 1 and kh == 1 and kw == 1:
586 # Overwrite this op as a Fully Connected Op
587 op.name += "_fc"
588 op.type = Op.FullyConnected
589 op.attrs = {
590 "weights_format": 0,
591 }
592 # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
593 weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

597 DebugDatabase.add_optimised(op, op)
598 return op
599
600
601def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
602 if op.run_on_npu and op.type.is_relu_op():
603 ifm = op.inputs[0]
604 ofm = op.outputs[0]
605 # Relu with differing IFM and OFM scaling cannot be fused with another primary op
606 # and requires its own to be inserted
607 if not check_quantized_tens_scaling_equal(ifm, ofm):
608 # Override this op with its own primary op (avgpool)
609 relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
610 # And fuse the original activation function to it
611 relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
617 ifm.consumer_list.remove(op)
618
619 relu_fused_op.add_input_tensor(ifm)
620 relu_fused_op.set_output_tensor(ofm)
621 relu_fused_op.set_ifm_ofm_shapes()
622 op = relu_fused_op
623 return op
624
625
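# Expand scalar or lower-rank inputs of binary elementwise ops to the same rank as the other input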
626def fixup_elementwise_with_scalars(op, arch, nng):
627 if op.type.is_binary_elementwise_op():
628 ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
629 if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
630 diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
631 if diff > 0:
632 ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
633 elif diff < 0:
634 ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
        elif ifm_tensor.shape == [] and ifm_tensor.values is None:
            # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
            ifm_tensor.storage_shape = ifm_tensor.shape
        elif ifm2_tensor.shape == [] and ifm2_tensor.values is None:
            # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
641 ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
642 ifm2_tensor.storage_shape = ifm2_tensor.shape
643 return op
644
645
646def convert_softmax(op, arch, nng):
647 if op.type == Op.Softmax and op.run_on_npu:
648 softmax = SoftMax(op)
649 op = softmax.get_graph()
650 return op
651
652
653def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
654 r"""Whenever there is a subgraph with this topology:
655
656 Input X For X = -1 or X > 0
657 | \ / This subgraph can be replaced with either
658 | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
659 | /
660 Max
661 """
662
663 if op.type == Op.Maximum:
664 # finds the Mul input(s) to the Max
665 muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
666 if len(muls) == 1:
667 mul = muls[0].ops[0]
668 elif len(muls) == 2:
669 # In the case both inputs are Muls, find the one with the same input as the Max
670 mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
671 else:
672 # No Mul inputs
673 return op
674
675 # make sure the Mul doesn't have any other consumers
676 mul_ofm = mul.outputs[0]
677 if len(mul_ofm.consumers()) != 1:
678 return op
679 # make sure the Mul doesn't have a fused activation function
680 if mul.activation:
681 return op
682 ifm, ofm = op.get_ifm_ofm()
683 if ifm is None or ofm is None:
684 return op
685
686 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
687 return op
688 if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
689 # rewrite to LeakyRelu currently only makes sense if the quantization is identical
690 return op
691
692 # finds the branched input that goes to both the Max and the Mul
693 shared = set(op.inputs) & set(mul.inputs)
694 if len(shared) == 1:
695 shared_in = shared.pop()
696 # find the constant scalar input to the Mul
697 const_tens = (set(mul.inputs) - {shared_in}).pop()
698 # check that it is a scalar
699 if const_tens.shape != []:
700 return op
701 const = const_tens.ops[0]
702 # check that it is a constant
703 if const.type != Op.Const:
704 return op
705 # Remove the Mul from the shared input's consumers
706 shared_in.consumer_list.remove(mul)
707 else:
708 return op
709
710 val = const.outputs[0].values
711 if val >= 0:
712 new_op = Op.LeakyRelu
713 op.attrs["alpha"] = val
714 # to produce bit exact results, the alpha is not enough;
715 # save additional scaling info in attr "alpha_scale", to be used as input
716 # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
719 mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
720 mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
721 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
722 op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
723 elif val == -1:
724 new_op = Op.Abs
725 else:
726 return op
727
728 op.type = new_op
729 op.name = op.name.replace("Maximum", new_op.name)
730 op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
731 op.inputs = [shared_in]
732 op.set_ifm_ofm_shapes()
733
734 # Record optimisation in debug database
735 DebugDatabase.add_optimised(op, op)
736
737 return op
738
739
740def convert_hardswish_to_lut(op, arch, nng):
741 if op.type == Op.HardSwish:
742 ifm, ofm = op.get_ifm_ofm()
743 # Generate the LUT
744 ifm_scale = np.double(ifm.quantization.scale_f32)
745 ofm_scale = np.double(ofm.quantization.scale_f32)
746 zp_in = ifm.quantization.zero_point
747 zp_out = ofm.quantization.zero_point
748 ifm_scale_hires = (1 / 128) * ifm_scale
749 relu_multiplier = np.double(3 / 32768)
750 out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
751 relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
752 # Use 16bit scale
753 out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
754 relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
755
756 values = []
757 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
758 quantized_min = min(ix)
759 quantized_max = max(ix)
760 for x in ix:
761 input_value = x - zp_in
762 input_value_hires = input_value * 128
763 # Compute the input value on essentially the output scale, not shifted yet
764 input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
765 # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
766 relu_value = np.int16(input_value_hires)
767 if relu_shift < 31:
768 relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
769
770 relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
771
772 if relu_shift < 31:
773 relu_value = fp_math.shift_left16(relu_value, 1)
774
775 if relu_shift > 31:
776 relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
777
778 # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
779 # Now convert that to a 16bit fixedpoint value in [0, 1]
780 relu_value = (relu_value + (1 << 15)) >> 1
781 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
782 shift = 31 - out_shift
783 shift = -shift if shift < 0 else 0
784 # Finally apply the output shift
785 lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
786 lut_result = min(quantized_max, max(quantized_min, lut_result))
787 values.append(lut_result)
788 return convert_to_lut(op, values, "hardswish")
789 return op
790
791
792def convert_lrelu_to_mul_max(op, arch):
793 # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
794 # (the opposite of convert_mul_max_to_abs_or_lrelu)
795 ifm, ofm = op.get_ifm_ofm()
796 if ifm is None or ofm is None:
797 return op
798
799 # Add multiplication with alpha
800 mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
801 mul_alpha.add_input_tensor(ifm)
802 # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
820 fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
821 mul_alpha.set_output_tensor(fm_alpha)
822 mul_alpha.set_ifm_ofm_shapes()
823 DebugDatabase.add_optimised(op, mul_alpha)
824
825 if check_quantized_tens_scaling_equal(ifm, ofm):
826 # No identity multiplication is needed
827 fm_id = ifm
828 else:
829 # Add multiplication with identity
830 mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
831 mul_identity.add_input_tensor(ifm)
832 # Create const tensor containing identity as scalar
833 quantization = ifm.quantization.clone()
834 quantization.min = 0
835 quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
838 identity_tens = create_const_tensor(
839 op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
840 )
841 mul_identity.add_input_tensor(identity_tens)
842 # Make sure that fm_id is allocated to a different address than fm_alpha
843 fm_id = ofm.clone(op.name + "_id", set_unique=True)
844 mul_identity.set_output_tensor(fm_id)
845 mul_identity.set_ifm_ofm_shapes()
846 DebugDatabase.add_optimised(op, mul_identity)
847
848 # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
849 op.type = Op.Maximum
850 op.name = op.name.replace("LeakyRelu", "Maximum")
851 op.inputs = []
852 ifm.consumer_list.remove(op)
853 op.add_input_tensor(fm_alpha)
854 op.add_input_tensor(fm_id)
855 op.set_ifm_ofm_shapes()
856
857 DebugDatabase.add_optimised(op, op)
858 return op
859
860
861def convert_to_lut(op, lut_values, lut_name):
862 # Rewrite the operation by Add with scalar 0 + LUT activation
863 ifm = op.inputs[0]
864 if ifm is None:
865 return op
866 assert ifm.dtype.size_in_bytes() == 1
867 op.type = Op.Add
868 op.name = op.name + "_lut_" + lut_name
869 # Mark as no-op to enable potential fusing optimizations
870 op.attrs["is_nop"] = True
871 # Create an input tensor containing scalar zero
872 quantization = QuantizationParameters(0.0, 255.0)
873 quantization.scale_f32 = ifm.quantization.scale_f32
874 quantization.zero_point = 0
875 tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
876 op.add_input_tensor(tens)
877 op.ifm_shapes.append(Shape4D(tens.shape))
878
879 # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
880 # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
881 # should be the same as the IFM
882 op.forced_output_quantization = ifm.quantization
883 lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
884 op.set_activation_lut(lut_tensor)
885 op.set_ifm_ofm_shapes()
886 return op
887
888
889def convert_to_lut8(op, fn, fn_name):
890 # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
891 # fn is a function(real) -> real
892 ifm, ofm = op.get_ifm_ofm()
893 if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
894 return op
895 # Generate the LUT
896 ifm_scale = np.double(ifm.quantization.scale_f32)
897 ofm_scale = np.double(ofm.quantization.scale_f32)
898 zp_in = ifm.quantization.zero_point
899 zp_out = ofm.quantization.zero_point
900 values = []
901 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
902 quantized_min = min(ix)
903 quantized_max = max(ix)
904 for x in ix:
905 x_real = ifm_scale * (x - zp_in)
906 y_real = fn(x_real)
907 lut_result = round_away_zero(zp_out + y_real / ofm_scale)
908 lut_result = min(quantized_max, max(quantized_min, lut_result))
909 values.append(lut_result)
910 return convert_to_lut(op, values, fn_name)
911
912
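# Convert LeakyRelu to a LUT where entries below the input zero point are scaled by alpha and the
# remaining entries by the identity scale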
913def convert_lrelu_to_lut(op, arch):
914 ifm, ofm = op.get_ifm_ofm()
915 # Generate the LUT
916 alpha = op.attrs["alpha"]
917 ifm_scale = np.double(ifm.quantization.scale_f32)
918 ofm_scale = np.double(ofm.quantization.scale_f32)
919 zp_in = ifm.quantization.zero_point
920 zp_out = ofm.quantization.zero_point
921 identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
922 alpha_scalar = 1
923 alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
924 if "alpha_scaling" in op.attrs:
925 # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
926 alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
927 values = []
928 ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
929 quantized_min = min(ix)
930 quantized_max = max(ix)
931 for x in ix:
932 if x < zp_in:
933 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
934 alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
935 )
936 else:
937 lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
938 lut_result = min(quantized_max, max(quantized_min, lut_result))
939 values.append(lut_result)
940 return convert_to_lut(op, values, "lrelu")
941
942
943def convert_lrelu(op, arch, nng):
944 # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
945 if op.type != Op.LeakyRelu:
946 return op
947 ifm, ofm = op.get_ifm_ofm()
948 if ifm is None or ofm is None:
949 return op
950 if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
951 # use LUT for int8/uint8
952 return convert_lrelu_to_lut(op, arch)
953 if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
954 # use LeakyRelu unmodified for int16 with equal input/output scaling
955 return op
956 return convert_lrelu_to_mul_max(op, arch)
957
958
959def convert_tanh_sigmoid_to_lut(op, arch, nng):
960 # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
961 if op.type == Op.Sigmoid:
962 return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
963 elif op.type == Op.Tanh:
964 return convert_to_lut8(op, math.tanh, "tanh")
965 return op
966
967
def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)

972
973def fuse_activation_function_with_prev(op, arch, nng):
974 # if op is a no-op: attempts to move the activation function to the preceding op
975 if not op.attrs.get("is_nop", False) or op.activation is None:
976 return op
977 ifm, ofm = op.get_ifm_ofm()
978 if ifm is None or ofm is None:
979 return op
980 # finds the input(s) to the operation
981 prev_op = ifm.ops[0]
982 # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
983 fuse = (
984 prev_op.run_on_npu
985 and prev_op.type.npu_block_type != NpuBlockType.Default
986 and len(ifm.ops) == 1
987 and len(prev_op.outputs[0].consumers()) == 1
988 and prev_op.activation is None
989 )
990 if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
991 # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
992 # LUT currently only works correctly for elementwise ops
993 fuse = False
994 if not fuse:
995 return op
996 # Move the fused activation function + corresponding info to prev_op
997 prev_op.activation = op.activation
998 prev_op.forced_output_quantization = op.forced_output_quantization
999 if op.activation_lut is not None:
1000 prev_op.set_activation_lut(op.activation_lut)
1001 # Bypass op
1002 prev_op.set_output_tensor(ofm)
1003 DebugDatabase.add_optimised(op, prev_op)
1004 return op
1005
1006
1007def _leading_pad_ok(leading_pad, stride, kernel_size):
1008 # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
1009 # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
1010 max_size = kernel_size // 2
1011 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1012
1013
1014def replace_pad_by_hw_pad(op: Operation, arch, nng):
1015 """
1016 Tries to completely remove a PAD operator by using hardware padding.
1017 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1018 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1019 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1020 if both operations can be run on the NPU.
1021 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1022 """
1023 if (
1024 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
1025 and op.run_on_npu
1026 and op.attrs["padding"] == Padding.VALID
1027 ):
1028 pad_op = op.ifm.ops[0]
1029 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1030 return op
1031 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1032 return op
1033 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1034 k = op.kernel
1035 k_w, k_h = k.dilated_wh()
1036
1037 # Check if the PAD operator can be replaced by hardware padding
1038 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1039 # Too much padding, it would require hardware padding to actually insert zeros
1040 return op
1041 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1042 return op
1043
1044 if op.type.is_avgpool_op():
1045 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1046 for pad, k_size in (
1047 (left, k_w),
1048 (right, k_w),
1049 (top, k_h),
1050 (bottom, k_h),
1051 ):
1052 if pad not in (0, k_size // 2):
1053 return op
1054 # Average pool is converted to depthwise, because NPU average pool + same padding
1055 # has a special implementation that is different from PAD followed by average pool with
1056 # valid padding.
1057 k_w, k_h = op.kernel.width, op.kernel.height
1058 ifm = op.ifm
1059 # Remember other inputs
1060 other_inputs = op.inputs[1:]
1061 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1062 quantization = QuantizationParameters(0.0, 255.0)
1063 quantization.scale_f32 = 1.0 / (k_w * k_h)
1064 quantization.zero_point = 0
1065 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1066 weights = np.full(shape, 1)
1067
1068 weight_tens = create_const_tensor(
1069 op.name + "_weights",
1070 shape,
1071 op.ifm.dtype,
1072 weights,
1073 np.uint8,
1074 purpose=TensorPurpose.Weights,
1075 quantization=quantization,
1076 )
            weight_tens.values = weights
            op.type = Op.DepthwiseConv2DBias
1079 op.inputs = []
1080 op.add_input_tensor(ifm)
1081 op.add_input_tensor(weight_tens)
1082 # Add bias tensor, all biases set to 0
1083 op.inputs.append(None)
1084 fixup_bias_tensors(op, arch, nng)
1085 # Add other inputs
1086 op.inputs.extend(other_inputs)
1087 op.rounding_mode = NpuRoundingMode.NATURAL
1088
1089 # Bypass the PAD operator
1090 op.set_input_tensor(pad_op.ifm, 0)
1091 # Adjust the padding attributes of the convolution operator
1092 op.attrs["padding"] = Padding.EXPLICIT
1093 op.attrs["explicit_padding"] = (top, left, bottom, right)
1094 op.set_ifm_ofm_shapes()
1095 return op
1096
1097
1098def convert_pad(op: Operation, arch, nng):
1099 """
1100 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1101 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1102 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1103 """
1104 if op.type != Op.Pad or not op.run_on_npu:
1105 return op
1106 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1107
1108 ifm = op.ifm
1109 assert ifm is not None
1110 ifm_shape = Shape4D(ifm.shape)
1111 ofm = op.ofm
1112 assert ofm is not None
1113 ofm.ops = []
1114 ofm_shape = op.ofm_shapes[0]
1115
1116 # Average pool op that copies IFM to the right place inside the OFM
1117 shp0 = Shape4D(0, 0, 0, 0)
1118 shp_top = shp0.with_height(top)
1119 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1120 avgpool_op.activation = op.activation
1121 quant = ofm.quantization
1122 pad_value = quant.zero_point
1123 # Add operations that fill the borders of the OFM
1124 if top > 0:
1125 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1126 zero_tens = create_const_tensor(
1127 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1128 )
1129 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1130 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1131 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1132 if bottom > 0:
1133 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1134 zero_tens = create_const_tensor(
1135 op.name + "_bottom",
1136 shape.as_list(),
1137 ofm.dtype,
1138 shape.elements() * [pad_value],
1139 np.uint8,
1140 quantization=quant,
1141 )
1142 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1143 create_avg_pool_for_concat(
1144 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1145 )
1146 if left > 0:
1147 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1148 zero_tens = create_const_tensor(
1149 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1150 )
1151 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1152 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1153 if right > 0:
1154 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1155 zero_tens = create_const_tensor(
1156 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1157 )
1158 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1159 create_avg_pool_for_concat(
1160 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1161 )
1162
1163 op.type = Op.ConcatTFLite
1164 return avgpool_op
1165
1166
1167def add_attrs_to_resizebilinear(op, arch, nng):
1168 if op.type == Op.ResizeBilinear and op.run_on_npu:
1169 input_tensor = op.inputs[0]
1170 input_shape = op.ifm_shapes[0]
1171 upscaled_height = input_shape.height * 2
1172 upscaled_width = input_shape.width * 2
1173 out_shape = op.ofm_shapes[0]
1174 if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
1175 # this means the output is supposed to be a x2 upscale,
1176 # so we need to do SAME padding
1177 op.attrs["padding"] = Padding.SAME
1178 elif (
1179 op.attrs["align_corners"]
1180 and out_shape.height == (upscaled_height - 1)
1181 and out_shape.width == (upscaled_width - 1)
1182 ):
1183 # here we can just run the avg pool without padding and
1184 # produce a (M * 2 - 1, N * 2 - 1) sized output
1185 op.attrs["padding"] = Padding.VALID
1186 else:
1187 return op
1188 input_tensor.resampling_mode = resampling_mode.NEAREST
1189 op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
1190 return op
1191
1192
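# Add a zero-filled bias tensor to ops that need a bias but do not have one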
1193def fixup_bias_tensors(op, arch, nng):
1194 if op.type.needs_bias() and op.bias is None:
1195 # Op has no bias, add bias tensor filled with zeros
1196 nr_biases = op.inputs[1].shape[-1]
1197 bias_values = [0] * nr_biases
1198 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
        op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1200
1201 return op
1202
1203
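# Convert a Mean op to an AvgPool or a DepthwiseConv2DBias with unit weights, depending on the
# reduction axes and the input/output quantization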
1204def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1205 if op.type == Op.Mean and op.run_on_npu:
1206 keep_dims = op.attrs.get("keep_dims", False)
1207 inp, axis = op.inputs
1208 shape = inp.shape
1209 dims = len(shape)
1210
1211 # Height and width axes have different index depending on dimensions
1212 if axis.shape == [] or axis.shape[0] == 1: # single axis
1213 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1214 if dims in (2, 3):
1215 if axis == 0:
1216 h, w = shape[axis], 1
1217 else:
1218 h, w = 1, shape[axis]
1219 else:
1220 if axis == 1:
1221 h, w = shape[axis], 1
1222 else:
1223 h, w = 1, shape[axis]
1224 else: # multiple axes
1225 axis = sorted(axis.values)
1226 h, w = [shape[i] for i in axis]
1227
1228 # Set necessary depthwise attributes
1229 op.attrs.update(
1230 {
1231 "padding": Padding.VALID,
1232 "stride_h": 1,
1233 "stride_w": 1,
1234 "strides": (1, 1, 1, 1),
1235 "depth_multiplier": 1,
1236 "channel_multiplier": 1,
1237 "dilation_h_factor": 1,
1238 "dilation_w_factor": 1,
1239 "dilation": (1, 1, 1, 1),
1240 }
1241 )
1242 # Change op type
1243 op.type = Op.DepthwiseConv2DBias
1244 # Set IFM/OFM shapes after changing op type
1245 op.set_ifm_ofm_shapes()
1246
1247 weight_scale, bias = 1, None
1248 ofmq, ifmq = op.ofm.quantization, inp.quantization
1249 # Set rounding mode, scaling and zero point based on which reference implementation to match
1250 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1251 if inp.dtype == DataType.uint8:
1252 # This attribute means a different scaling calculation is used in order to match reference
1253 op.low_precision_scaling = True
1254 weight_scale = h * w
1255 # Set zero points to 0 as they will be adjusted for with bias term
1256 foq = ofmq.clone()
1257 foq.zero_point = 0
1258 fiq = ifmq.clone()
1259 fiq.zero_point = 0
1260 op.forced_input_quantization = fiq
1261 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1262 # If the bias term is outside uint8 range, we need an Add op to apply it.
1263 if bias_term < 0 or bias_term > 255:
1264 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1265 # Bias term has higher bitness (i32) than input/output (u8).
1266 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1267 # the bias can only effectively assume values in the range [-255, 255].
1268 intermediate.dtype = DataType.int16
1269 intermediate.quantization.zero_point = 0
1270 add_op = Operation(Op.Add, op.name + "_bias")
1271 add_op.forced_output_quantization = foq
1272 add_op.add_input_tensor(intermediate)
1273 quant = QuantizationParameters()
1274 quant.zero_point = 0
1275 bias_term_tens = create_const_tensor(
                        op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
                    )
1278 add_op.add_input_tensor(bias_term_tens)
1279 add_op.set_output_tensor(op.ofm)
1280 add_op.set_ifm_ofm_shapes()
1281 add_op.activation = op.activation
1282 op.activation = None
1283 op.set_output_tensor(intermediate)
1284 op.set_ifm_ofm_shapes()
1285 # If not, we can just do it with the OFM zero point.
1286 else:
1287 foq.zero_point = bias_term
1288 op.forced_output_quantization = foq
1289 else:
1290 assert inp.dtype == DataType.int8
1291 # Use a depthwise to calculate the sum,
1292 # followed by a multiplication with 1/N to get the MEAN
1293 weight_scale = 1
1294 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1295 intermediate.dtype = DataType.int16
1296 mul_op = Operation(Op.Mul, op.name + "_mul")
1297 mul_op.add_input_tensor(intermediate)
1298 # Create scalar containing 1/N
1299 quant = QuantizationParameters()
1300 quant.zero_point = 0
1301 # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
1302 # while rounding mode NATURAL would round this to -1.
1303 # This can only occur if N is even, and can be emulated by
1304 # multiplying with a number that is slightly smaller than 1/N.
1305 # It must be so small that other roundings are not affected;
1306 # the calculated value is based on worst case,
1307 # which is sum 256 * N (the maximum sum that can occur with int8)
1308 n = int(h * w)
1309 eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
1310 quant.scale_f32 = 1 / (n - eps)
1311 scalar = create_const_tensor(
1312 op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
1313 )
1314 mul_op.add_input_tensor(scalar)
1315 mul_op.set_output_tensor(op.ofm)
1316 mul_op.set_ifm_ofm_shapes()
1317 mul_op.rounding_mode = NpuRoundingMode.NATURAL
1318 mul_op.activation = op.activation
1319 op.activation = None
1320 op.set_output_tensor(intermediate)
1321 op.set_ifm_ofm_shapes()
1322 elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
1323 # Here we can just use a simple AvgPool with truncating rounding,
1324 # as we're emulating simple integer division.
1325 op.rounding_mode = NpuRoundingMode.TRUNCATE
1326 op.type = Op.AvgPool
1327 op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
1328 else:
1329 op.rounding_mode = NpuRoundingMode.NATURAL
1330 weight_scale = 1 / (h * w)
1331 # Input zero point is adjusted after mean calculation, so we emulate that with a bias
1332 bias = -ifmq.zero_point * h * w
1333 fiq = ifmq.clone()
1334 fiq.zero_point = 0
1335 op.forced_input_quantization = fiq
1336
1337 # Change dimensions to 4
1338 if dims < 4:
1339 shape = [1] + shape
1340 if dims == 2:
1341 shape += [1]
1342
1343 # If height is greater than max kernel height, reshape to from HxW to 1x(HxW)
1344 if h > 64:
1345 shape = [shape[0], 1, h * w, shape[3]]
1346 op.ifm_shapes[0] = Shape4D(shape)
1347 if h > 256 and op.type == Op.AvgPool:
1348 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
1349
1350 # If the AvgPool version is used, we don't need to do anything else
1351 if op.type == Op.AvgPool:
1352 return op
1353
1354 # Make unit weight tensor quantization
1355 weight_quant = ifmq.clone()
1356 weight_quant.min = 0
1357 weight_quant.max = 255
1358 weight_quant.scale_f32 = weight_scale
1359 weight_quant.zero_point = 0
1360
1361 # Set weight shape to [H,W,C,B]
1362 weight_shape = shape[1:4] + [shape[0]]
1363 # Add unit weight tensor
1364 op.set_input_tensor(
1365 create_const_tensor(
1366 "weights",
1367 weight_shape,
1368 inp.dtype,
1369 np.ones(weight_shape),
1370 value_dtype=np.uint8,
1371 quantization=weight_quant,
1372 ),
1373 1,
1374 )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

1377 # Add None bias tensor
1378 op.inputs.append(None)
1379 # Add bias tensor
1380 if bias:
1381 bias_shape = [shape[-1]]
1382 op.set_input_tensor(
1383 create_const_tensor(
                    "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
                ),
1386 2,
1387 )
1388
1389 return op
1390
1391
1392def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op
1395
1396
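# Entry point for the TFLite graph optimisation: runs the pre-processing, rewrite and clean-up
# passes over all subgraphs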
1397def tflite_optimise_graph(nng, arch):
1398 # Pre-processing step
1399 pre_process_list = [
1400 supported_operator_check,
1401 set_ifm_ofm_op_shapes,
1402 ]
1403
1404 for idx, sg in enumerate(nng.subgraphs):
1405 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1406 nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
1407 )
1408
1409 # Handle Concat Ops
1410 for idx, sg in enumerate(nng.subgraphs):
1411 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
1412 sg.refresh_after_modification()
1413
1414 # Handle Split Ops
1415 for idx, sg in enumerate(nng.subgraphs):
1416 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1417 nng,
1418 sg,
1419 arch,
1420 [],
1421 [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
1422 rewrite_unsupported=False,
1423 )
1424
1425 for idx, sg in enumerate(nng.subgraphs):
1426 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1427 nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
1428 )
1429
1430 # Handle sg input output
1431 for idx, sg in enumerate(nng.subgraphs):
1432 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1433 nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
1434 )
1435
    # Removal of memory only operators
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
        sg.refresh_after_modification()
1440
1441 # Rewrite of operators
1442 op_rewrite_list = [
1443 set_tensor_equivalence,
1444 convert_mean_to_depthwise_conv_or_avgpool,
1445 convert_depthwise_to_conv,
1446 convert_conv_to_fc,
1447 convert_softmax,
1448 optimise_strided_conv,
1449 convert_hardswish_to_lut,
1450 rewrite_fully_connected_input,
1451 convert_batched_fc_shape,
1452 fixup_conv2d_backprop,
1453 fixup_relus_with_differing_ifm_ofm_scaling,
1454 fixup_elementwise_with_scalars,
1455 reorder_depthwise_weights,
1456 fixup_resizebilinear,
1457 fixup_bias_tensors,
1458 convert_mul_max_to_abs_or_lrelu,
1459 convert_lrelu,
1460 convert_tanh_sigmoid_to_lut,
1461 replace_pad_by_hw_pad,
1462 ]
1463
1464 for idx, sg in enumerate(nng.subgraphs):
1465 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1466 nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
1467 )
1468
1469 for idx, sg in enumerate(nng.subgraphs):
1470 # remove passthrough tensors and attempt further optimizations
1471 nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
1472 nng,
1473 sg,
1474 arch,
1475 [remove_passthrough_tensor],
1476 [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
1477 )
1478
1479 # Removal of SplitSliceRead, need to be done after optimisation has been performed,
1480 # since ifm/ofm_shapes are of importance to this function
1481 for sg in nng.subgraphs:
1482 rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
1483 sg.refresh_after_modification()
1484
1485 return nng