# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of a TensorFlow Lite based network graph, using the rewrite_graph module
# to do the traversal of the graph.
import math
import uuid

import numpy as np

from . import fp_math
from . import rewrite_graph
from . import scaling
from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import fix_sg_input_output
from .graph_optimiser_util import memory_only_ops
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
from .shape4d import Shape4D
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype

passthrough_nodes = (Op.Identity,)


def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an average pool for the given concat op/input feature map"""
    ofm = concat_op.ofm
    avgpool_op = create_avgpool_nop(name)
    avgpool_op.inputs = [ifm]
    avgpool_op.outputs = [ofm]

    avgpool_op.write_offset = write_offset
    avgpool_op.write_shape = ifm_shape
    ofm.ops.append(avgpool_op)
    DebugDatabase.add_optimised(concat_op, avgpool_op)
    avgpool_op.ifm_shapes.append(ifm_shape)
    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    avgpool_op.memory_function = Op.ConcatSliceWrite
    return avgpool_op


def remove_passthrough_tensor(tens, arch, nng):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type.is_concat_op():
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    unfuse_activation_function(op)

    if op.type == Op.Pack:
        # Pack is also referred to as Stack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis

        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]

        axis_4D = axis + (4 - len(desired_shape))

        for idx, inp in enumerate(op.inputs):
            op.ifm_shapes[idx] = Shape4D(desired_shape)
        op.type = Op.PackReshaped

    inputs, axis = op.get_concat_inputs_axis()
    for idx, inp in enumerate(inputs):
        if op.type != Op.PackReshaped:
            op.ifm_shapes[idx] = Shape4D(inp.shape)
            if axis >= 0:
                axis_4D = axis + (4 - len(inp.shape))
            else:
                axis_4D = axis
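        # Each input is written into the OFM at an increasing offset along the concat axis; e.g. two
        # 1x8x8x16 inputs concatenated along depth get write offsets [0, 0, 0, 0] and [0, 0, 0, 16]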
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_avg_pool_for_concat(
            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
        )
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def rewrite_split_ops(tens, arch, nng):

    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation(Op.SplitSliceRead, split_op.name)
        new_op.inputs = [inp]
        ofm_shape_idx = 0
        if None in (offset_end, offset_start):
            read_shape = None
        else:
            # the read shape is relative to each start offset
            read_shape = [oe - os for oe, os in zip(offset_end, offset_start)]
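            # e.g. offset_start [0, 0, 4, 0] and offset_end [1, 8, 8, 16] give a read shape of [1, 8, 4, 16]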

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * 4
            axis_4D_list = split_op.attrs.get("split_axis_4D", None)  # Present for UnpackReshaped and some StridedSlice
            for idx, out in enumerate(outputs):
                if axis_4D_list is not None:
                    axis_4D = axis_4D_list[idx]
                else:
                    split_op.ofm_shapes[idx] = Shape4D(out.shape)
                    if axis >= 0:
                        axis_4D = axis + (4 - len(out.shape))
                    else:
                        axis_4D = axis

                if out == tens:
                    ofm_shape_idx = idx
                    read_shape = split_op.ofm_shapes[idx]
                    break

                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]

        new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
        new_op.read_shapes[0] = read_shape
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)
        new_op.ifm_shapes.append(Shape4D(inp.shape))
        new_op.ofm_shapes.append(split_op.ofm_shapes[ofm_shape_idx])
        DebugDatabase.add_optimised(split_op, new_op)

    return tens


def remove_SplitSliceRead(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if an avgpool needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type not in memory_only_ops
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            avgpool_op = create_avgpool_nop(op.name + "_avgpool")
            avgpool_op.add_input_tensor(op.ifm)
            avgpool_op.outputs = [op.ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(avgpool_op)
            avgpool_op.ifm_shapes.append(op.ifm_shapes[0])
            avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
            avgpool_op.read_offsets[0] = op.read_offsets[0]
            avgpool_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, avgpool_op)


def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
    if padding_type == Padding.SAME:
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == Padding.VALID:
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    elif padding_type == Padding.EXPLICIT:
        # Padding is specified in a PAD operator which has been bypassed.
        top, left, bottom, right = explicit_padding
        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_shape, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == Padding.SAME:
        ypad = needed_total_padding(int(input_shape.height) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_shape.width) * upscaling_factor, int(stride[2]), int(kernel_width))
        right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
        bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)
    elif padding_type == Padding.VALID:
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        raise UnsupportedFeatureError(f"Unsupported padding = {padding_type} for up-scaled padding calculation")
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch, nng):
    if op.type == Op.Conv2DBackpropInput:
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = Op.Conv2DBackpropInputSwitchedBias
        op.ifm_resampling_mode = resampling_mode.TRANSPOSE

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
def convert_resize_1x1_to_add(op):
    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
    op.name = op.name + "_add"
    # Create an input tensor filled with zeros
    shape = op.ofm_shapes[0].as_list()
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape, tens.dtype.as_numpy_type())
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
    ifm = op.ifm
    ofm = op.ofm
    output_depth = ofm.shape[-1]
    dw_op_attrs = {
        "padding": Padding.VALID,
        "stride_h": 1,
        "stride_w": 1,
        "strides": (1, 1, 1, 1),
        "depth_multiplier": 1,
        "channel_multiplier": 1,
        "dilation_h_factor": 1,
        "dilation_w_factor": 1,
        "dilation": (1, 1, 1, 1),
    }

    # change the resize op to depthwise
    op.type = Op.DepthwiseConv2DBias
    op.attrs.update(dw_op_attrs)
    op.set_input_tensor(ifm, 0)  # ifm tensor index
    op.activation = None

    # add input resample to resize by x2
    op.ifm_resampling_mode = resampling_mode.NEAREST

    # don't care about the rounding mode as it is nearest neighbor

    # setup weight tensor
    weight_quant = QuantizationParameters()
    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
    weight_quant.zero_point = 0
    weight_quant.quant_dim = 0
    ofm_dtype = ofm.dtype
    if ofm_dtype == DataType.uint8:
        weight_value_dtype = np.uint8
        weight_quant.quant_min = 0
        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
    else:
        if ofm_dtype == DataType.int8:
            weight_value_dtype = np.int8
        else:
            assert ofm_dtype == DataType.int16
            weight_value_dtype = np.int16

        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1

    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO

    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
    # below-and-right (i.e. next) to it (D).
    # 0---1---2
    # | A | B |
    # 1---*---+
    # | C | D |
    # 2---+---+
    weight_values = [0] * (upscale_factor * upscale_factor)
    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
    weight_values[centre_coeff] = 1
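    # e.g. for upscale_factor 2 this gives centre_coeff = 1 * 2 + 1 = 3, i.e. weight_values = [0, 0, 0, 1],
    # which selects the bottom-right value 'D' of each 2x2 block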

    # add weight tensor, this will discard the size tensor of the resize op
    op.set_input_tensor(
        create_const_tensor(
            "weights",
            weight_shape,
            ofm.dtype,
            np.array(weight_values).reshape(weight_shape),
            value_dtype=weight_value_dtype,
            quantization=weight_quant,
        ),
        1,  # inputs tensor weight index
    )

    # setup bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
    # need to append the bias tensor as resize ops only have 2 inputs
    assert len(op.inputs) == 2
    op.inputs.append(None)
    fixup_bias_tensors(op, None, None)

    # finally update the shape in case we've changed the tensor shapes or connections
    op.set_ifm_ofm_shapes()

    return op


# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
def convert_resize_to_upscale_and_average_pool(op):
    pre_op = op
    outputs = op.outputs
    dtype = op.ifm.dtype

    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
    op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
    op.ifm_resampling_mode = resampling_mode.NEAREST

    upscaled_shape = np.array(op.ifm_shapes[0].get_hw_as_list())

    # Get upscale factor that was calculated in the supported operators check
    upscale_factor = op.attrs["upscale_factor"]

    # Calculate how many times 2x2 upscaling needs to be performed
    # Force the result of round to be an integer. This is because the behaviour of rounding numpy.float64 values changed
    # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
    n = int(np.log2(upscale_factor))
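    # e.g. an x8 resize gives n = 3, i.e. two intermediate x2 upscaling stages plus the final x2 stage below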

    # Perform x2 upscaling n-1 times
    scaled_op = pre_op
    for count in range(n - 1):
        if count > 0:
            scaled_op = op.clone(f"_{count}")
            scaled_op.inputs[0] = pre_op.outputs[0]

        # Nearest neighbor x2 upscaling
        upscaled_shape = upscaled_shape * 2
        shape = op.ofm_shapes[0].as_list()
        shape[1:3] = upscaled_shape
        out_tens = Tensor(shape, dtype, f"{op.outputs[0].name}_{count}")
        out_tens.quantization = op.outputs[0].quantization.clone()
        scaled_op.set_output_tensor(out_tens)
        pre_op = scaled_op

        scaled_op.set_ifm_ofm_shapes()

    # Last x2 upscaling
    if n > 1:
        scaled_op = op.clone(f"_{n-1}")
        scaled_op.inputs[0] = pre_op.outputs[0]

    if scaled_op.original_type == Op.ResizeBilinear:
        if scaled_op.attrs["align_corners"]:
            # no padding
            scaled_op.attrs["padding"] = Padding.VALID
        else:
            # padding to the right and bottom (limits average pool to 8x8 kernel)
            scaled_op.attrs["padding"] = Padding.EXPLICIT
            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]

        # kernel size dependent on the upscaling factor
        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
    else:  # Op.ResizeNearestNeighbor
        if scaled_op.attrs["align_corners"]:
            # use depthwise conv to select the correct value
            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
        else:
            # keep 1x1 kernel and average pool
            pass

    scaled_op.outputs = outputs
    scaled_op.outputs[0].ops = [scaled_op]
    scaled_op.set_ifm_ofm_shapes()

    return op


def fixup_resize(op, arch, nng):
    if op.type.is_resize_op() and op.run_on_npu:
        if op.ifm_shapes[0] == op.ofm_shapes[0]:
            # Bypass the resize op which is essentially a NOP
            op.inputs = op.inputs[:1]
            op.type = Op.Identity
        elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
            convert_resize_1x1_to_add(op)
        else:
            convert_resize_to_upscale_and_average_pool(op)

    return op


def convert_nop_split_to_identity(op, arch, nng):
    if op.type == Op.Split and op.attrs.get("num_splits") == 1:
        # the list comprehension should return a list with a single tensor
        # if it shouldn't, remove_passthrough_tensor will fail appropriately
        op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
        op.type = Op.Identity
    return op


def rewrite_fully_connected_input(op: Operation, arch, nng):

    if op.type == Op.FullyConnected:
        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
        assert new_shape is not None, "Tensor can not be reshaped to 2D"
        op.ifm_shapes[0] = new_shape
    return op


def convert_batched_fc_shape(op, arch, nng):
    if op.type == Op.FullyConnected:
        # Check if the first dimension indicates batching
        if op.ifm_shapes[0].batch > 1:
            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
            n = op.ifm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
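            # e.g. a batch of 8 is mapped to a 2x4 spatial shape, so an IFM of 8xC becomes 1x2x4xC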

            # Reshape Weights to be 4D. IO becomes HWIO
            weight_tensor = op.inputs[1]
            weight_tensor.values = np.expand_dims(np.expand_dims(weight_tensor.values, axis=0), axis=0)
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            n = op.ofm_shapes[0].batch
            h, w = batching_split.get(n, (1, n))
            op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
    return op


def unfuse_activation_function(op):
    if op.type == Op.ConcatTFLite and op.run_on_npu and op.activation is not None:
        act_op = Operation(op.activation.op_type, op.name + op.activation.op_type.name)
        op.activation = None
        out_tens = op.outputs[0]
        intermediate_tens = out_tens.clone("_act_intermediate")
        act_op.set_output_tensor(out_tens)
        act_op.add_input_tensor(intermediate_tens)
        op.set_output_tensor(intermediate_tens)
        act_op.set_ifm_ofm_shapes()


def rewrite_stridedslice_output(op, arch, nng):
    if not op.run_on_npu or op.type != Op.StridedSlice:
        return op

    new_axis_mask = op.attrs["new_axis_mask"]
    shrink_axis_mask = op.attrs["shrink_axis_mask"]

    if shrink_axis_mask == 0 and new_axis_mask == 0:
        return op

    axis_4D = [0] * len(op.outputs)
    for idx, out_tens in enumerate(op.outputs):
        output_shape = list(out_tens.shape)

        if shrink_axis_mask != 0:
            n = 0
            axis = 0
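            # Clear one set bit of the mask per iteration; `prev_mask - shrink_axis_mask` isolates the lowest
            # set bit, so log2 of the difference gives the axis that the mask refers to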
            while shrink_axis_mask:
                prev_mask = shrink_axis_mask
                n += 1
                shrink_axis_mask &= shrink_axis_mask - 1
                axis = int(math.log2(prev_mask - shrink_axis_mask))
                output_shape = output_shape[:axis] + [1] + output_shape[axis:]

            assert len(out_tens.shape) == (len(op.inputs[0].shape) - n)
            op.attrs["shrink_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

        elif new_axis_mask != 0:
            n = 0
            axis = 0
            while new_axis_mask:
                prev_mask = new_axis_mask
                n += 1
                new_axis_mask &= new_axis_mask - 1
                axis = int(math.log2(prev_mask - new_axis_mask))
                output_shape = output_shape[:axis] + output_shape[(axis + 1) :]
                new_axis_mask >>= 1

            assert len(out_tens.shape) == (len(op.inputs[0].shape) + n)
            op.attrs["new_axis_mask"] = 0
            if axis >= 0:
                axis_4D[idx] = axis + (4 - len(output_shape))
            else:
                axis_4D[idx] = axis
            op.ofm_shapes[idx] = Shape4D(output_shape)

    op.attrs["split_axis_4D"] = axis_4D
    return op


def rewrite_unpack_output(op, arch, nng):
    tens = op.outputs[0]
    if op.run_on_npu and op.type == Op.Unpack:
        # Unpack is also referred to as Unstack
        axis = int(op.attrs["axis"])
        if axis < 0:  # Convert to positive axis
            axis = len(op.inputs[0].shape) + 1 + axis
        op.type = Op.UnpackReshaped
        desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        axis_4D = axis + (4 - len(desired_output_shape))
        op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)

        for idx, out_tens in enumerate(op.outputs):
            op.ofm_shapes[idx] = Shape4D(desired_output_shape)
    return op


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "padding" in op.attrs:
            input_shape = op.ifm_shapes[0]
            output_shape = op.ofm_shapes[0]
            if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
                kernel_size = op.inputs[1].shape[:2]
            elif op.type.is_pool_op() or op.type.npu_block_type == NpuBlockType.ReduceSum:
                kernel_size = op.attrs["ksize"][1:3]
            else:
                raise UnsupportedFeatureError(f"Unknown operation that uses padding: {optype_to_builtintype(op.type)}")

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                upscaling_factor = output_shape.height // input_shape.height
                padding, skirt = calc_upscaled_padding_and_skirt(
                    op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                )
            else:
                padding, skirt = calc_padding_and_skirt(
                    op.attrs["padding"],
                    op.kernel,
                    input_shape,
                    op.attrs.get("explicit_padding"),
                )

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


def reorder_depthwise_weights(op, arch, nng):
    if op.type.is_depthwise_conv2d_op():
        weight_tensor = op.inputs[1]
        weight_tensor.values = np.transpose(weight_tensor.values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def optimise_strided_conv(op, arch, nng):
    if op.type != Op.Conv2DBias or op.op_index != 0:
        return op
    stride_x, stride_y = op.get_kernel_stride()
    weight_tensor = op.weights
    ifm_shape = op.ifm_shapes[0]

    if (
        stride_x == 2
        and ifm_shape.depth <= 4
        and ifm_shape.width % 2 == 0
        and weight_tensor is not None
        and weight_tensor.shape[1] >= 2
    ):
        k_w, _ = op.get_kernel_size()
        curr_padding_x = needed_total_padding(ifm_shape.width, 2, k_w)
        optimised_padding_x = needed_total_padding(ifm_shape.width // 2, 1, (k_w + 1) // 2)
        if curr_padding_x != optimised_padding_x:
            # Horizontal padding would become different after optimisation; this would not work
            return op
        # IFM
        op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
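        # e.g. a 1x16x16x3 IFM is read as 1x16x8x6, so that a stride-1 convolution over the packed width
        # is equivalent to the original stride-2 convolution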

        # Weights
        weight_shape = weight_tensor.shape
        if weight_shape[1] % 2 != 0:
            weight_shape[1] = weight_shape[1] + 1
            padded_array = np.zeros(weight_shape)
            for i in range(weight_shape[0]):
                padded_array[i] = np.vstack(
                    [
                        weight_tensor.values[i],
                        np.full((1, weight_shape[2], weight_shape[3]), weight_tensor.quantization.zero_point),
                    ]
                )
            weight_tensor.values = padded_array
        weight_shape[1] //= 2
        weight_shape[2] *= 2
        weight_tensor.values = np.reshape(weight_tensor.values, weight_shape)
        weight_tensor.set_all_shapes(weight_shape)
        # If multiple copies of the weights are used, we could avoid
        # them having the same address by changing the value_id
        weight_tensor.value_id = uuid.uuid4()

        # Strides
        stride_x = 1
        op.attrs.update({"stride_w": stride_x, "stride_h": stride_y, "strides": (1, stride_y, stride_x, 1)})

    return op


def convert_conv_to_fc(op, arch, nng):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == Op.Conv2DBias:
        h = op.ifm_shapes[0].height
        w = op.ifm_shapes[0].width
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = Op.FullyConnected
            op.attrs = {
                "weights_format": 0,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.values = weight_tensor.values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.values.shape))

            DebugDatabase.add_optimised(op, op)
    return op


def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
    if op.run_on_npu and op.type.is_relu_op():
        ifm = op.inputs[0]
        ofm = op.outputs[0]
        # Relu with differing IFM and OFM scaling cannot be fused with another primary op
        # and requires its own to be inserted
        if not check_quantized_tens_scaling_equal(ifm, ofm):
            # Override this op with its own primary op (avgpool)
            relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
            # And fuse the original activation function to it
            relu_fused_op.activation = create_activation_function(op.type)
            # Add explicit rescaling
            rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
            multiplier, shift = scaling.quantise_scale(rescale)
            relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
            # Tidy up and assign the ifm and ofm to the new op
            ifm.consumer_list.remove(op)

            relu_fused_op.add_input_tensor(ifm)
            relu_fused_op.set_output_tensor(ofm)
            relu_fused_op.set_ifm_ofm_shapes()
            op = relu_fused_op
    return op


def convert_softmax(op, arch, nng):
    if op.type == Op.Softmax and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
    r"""Whenever there is a subgraph with this topology:

    Input    X         For X = -1 or X > 0
    |   \   /          This subgraph can be replaced with either
    |    Mul           an Abs (if X = -1) or a LeakyReLU (if X > 0)
    |   /
    Max
    """

    if op.type == Op.Maximum:
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == Op.Mul]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        mul_ofm = mul.outputs[0]
        if len(mul_ofm.consumers()) != 1:
            return op
        # make sure the Mul doesn't have a fused activation function
        if mul.activation:
            return op
        ifm, ofm = op.get_ifm_ofm()
        if ifm is None or ofm is None:
            return op

        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != Op.Const:
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = Op.LeakyRelu
            op.attrs["alpha"] = val
            # to produce bit exact results, the alpha is not enough;
            # save additional scaling info in attr "alpha_scale", to be used as input
            # to the LUT construction
            alpha_scalar = const_tens.values - const_tens.quantization.zero_point
            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
        elif val == -1:
            new_op = Op.Abs
        else:
            return op

        op.type = new_op
        op.name = op.name.replace("Maximum", new_op.name)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name)
        op.inputs = [shared_in]
        op.set_ifm_ofm_shapes()

        # Record optimisation in debug database
        DebugDatabase.add_optimised(op, op)

    return op


def convert_hardswish_to_lut(op, arch, nng):
    if op.type == Op.HardSwish:
        ifm, ofm = op.get_ifm_ofm()
        # Generate the LUT
        ifm_scale = np.double(ifm.quantization.scale_f32)
        ofm_scale = np.double(ofm.quantization.scale_f32)
        zp_in = ifm.quantization.zero_point
        zp_out = ofm.quantization.zero_point
        ifm_scale_hires = (1 / 128) * ifm_scale
        relu_multiplier = np.double(3 / 32768)
        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
        # Use 16bit scale
        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)

        values = []
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        quantized_min = min(ix)
        quantized_max = max(ix)
        for x in ix:
            input_value = x - zp_in
            input_value_hires = input_value * 128
            # Compute the input value on essentially the output scale, not shifted yet
            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
            relu_value = np.int16(input_value_hires)
            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)

            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)

            if relu_shift < 31:
                relu_value = fp_math.shift_left16(relu_value, 1)

            if relu_shift > 31:
                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)

            # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
            # Now convert that to a 16bit fixedpoint value in [0, 1]
            relu_value = (relu_value + (1 << 15)) >> 1
            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
            shift = 31 - out_shift
            shift = -shift if shift < 0 else 0
            # Finally apply the output shift
            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
            lut_result = min(quantized_max, max(quantized_min, lut_result))
            values.append(lut_result)
        return convert_to_lut(op, values, "hardswish")
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op

    # Add multiplication with alpha
    mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = np.float32(op.attrs["alpha"])
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.zero_point = 0
    if np.isinf(1 / alpha):
        # Handling of alpha near zero
        quantization.scale_f32 = np.float32(1)
        scalar = 0
    else:
        quantization.scale_f32 = alpha
        scalar = alpha
    alpha_tens = create_const_tensor(
        op.name + "_alpha_scalar", [], ifm.dtype, [scalar], np.float32, quantization=quantization
    )
    alpha_tens.values = np.array([1])
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
    mul_alpha.set_output_tensor(fm_alpha)
    mul_alpha.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, mul_alpha)

    if check_quantized_tens_scaling_equal(ifm, ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation(Op.Mul, op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = np.float32(1)
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        # Make sure that fm_id is allocated to a different address than fm_alpha
        fm_id = ofm.clone(op.name + "_id", set_unique=True)
        mul_identity.set_output_tensor(fm_id)
        mul_identity.set_ifm_ofm_shapes()
        DebugDatabase.add_optimised(op, mul_identity)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = Op.Maximum
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    op.set_ifm_ofm_shapes()

    DebugDatabase.add_optimised(op, op)
    return op


def convert_to_lut8(op, fn, fn_name):
    # Converts op to a no-op + int8/uint8 LUT which is generated with the given function.
    # fn is a function(real) -> real
    ifm, ofm = op.get_ifm_ofm()
    if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
        return op
    # Generate the LUT
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        x_real = ifm_scale * (x - zp_in)
        y_real = fn(x_real)
        lut_result = round_away_zero(zp_out + y_real / ofm_scale)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, fn_name)


def convert_lrelu_to_lut(op, arch):
    ifm, ofm = op.get_ifm_ofm()
    # Generate the LUT
    alpha = op.attrs["alpha"]
    ifm_scale = np.double(ifm.quantization.scale_f32)
    ofm_scale = np.double(ofm.quantization.scale_f32)
    zp_in = ifm.quantization.zero_point
    zp_out = ofm.quantization.zero_point
    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
    alpha_scalar = 1
    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
    if "alpha_scaling" in op.attrs:
        # The LeakyRelu was the result from convert_mul_max_to_abs_or_lrelu
        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
    values = []
    ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
    quantized_min = min(ix)
    quantized_max = max(ix)
    for x in ix:
        if x < zp_in:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
            )
        else:
            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        lut_result = min(quantized_max, max(quantized_min, lut_result))
        values.append(lut_result)
    return convert_to_lut(op, values, "lrelu")


def convert_lrelu(op, arch, nng):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != Op.LeakyRelu:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
        # use LUT for int8/uint8
        return convert_lrelu_to_lut(op, arch)
    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
        # use LeakyRelu unmodified for int16 with equal input/output scaling
        return op
    return convert_lrelu_to_mul_max(op, arch)


def convert_tanh_sigmoid_to_lut(op, arch, nng):
    # Converts int8/uint8 Sigmoid and Tanh to a LUT based solution
    if op.type == Op.Sigmoid:
        return convert_to_lut8(op, clamp_sigmoid, "sigmoid")
    elif op.type == Op.Tanh:
        return convert_to_lut8(op, math.tanh, "tanh")
    return op


def remove_memory_only_ops(op, arch):
    if op.run_on_npu and op.type in memory_only_ops:
        bypass_memory_only_ops(op)


def fuse_activation_function_with_prev(op, arch, nng):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.activation is None:
        return op
    ifm, ofm = op.get_ifm_ofm()
    if ifm is None or ofm is None:
        return op
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.type.npu_block_type != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.activation is None
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    prev_op.activation = op.activation
    prev_op.forced_output_quantization = op.forced_output_quantization
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(ofm)
    DebugDatabase.add_optimised(op, prev_op)
    return op


def _leading_pad_ok(leading_pad, stride, kernel_size):
    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
    max_size = kernel_size // 2
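    # e.g. kernel_size 7 and stride 2: leading pads of 0, 2 or 3 are accepted, while a pad of 1 is rejected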
1079 return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
1080
1081
1082def replace_pad_by_hw_pad(op: Operation, arch, nng):
1083 """
1084 Tries to completely remove a PAD operator by using hardware padding.
1085 E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
1086 is rewritten such that the PAD is removed, and the CONV uses SAME padding.
1087 Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
1088 if both operations can be run on the NPU.
1089 This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
1090 """
1091 if (
1092 (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
Tim Hall0ab2edc2022-02-23 17:58:02 +00001093 and op.type not in (Op.Conv2DBackpropInput, Op.Conv2DBackpropInputSwitchedBias)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001094 and op.run_on_npu
1095 and op.attrs["padding"] == Padding.VALID
1096 ):
1097 pad_op = op.ifm.ops[0]
1098 if pad_op.type != Op.Pad or not pad_op.run_on_npu:
1099 return op
1100 if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
1101 return op
1102 top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
1103 k = op.kernel
1104 k_w, k_h = k.dilated_wh()
1105
1106 # Check if the PAD operator can be replaced by hardware padding
1107 if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
1108 # Too much padding, it would require hardware padding to actually insert zeros
1109 return op
1110 if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
1111 return op
1112
1113 if op.type.is_avgpool_op():
1114 # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
1115 for pad, k_size in (
1116 (left, k_w),
1117 (right, k_w),
1118 (top, k_h),
1119 (bottom, k_h),
1120 ):
1121 if pad not in (0, k_size // 2):
1122 return op
1123 # Average pool is converted to depthwise, because NPU average pool + same padding
1124 # has a special implementation that is different from PAD followed by average pool with
1125 # valid padding.
1126 k_w, k_h = op.kernel.width, op.kernel.height
1127 ifm = op.ifm
1128 # Remember other inputs
1129 other_inputs = op.inputs[1:]
1130 # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
1131 quantization = QuantizationParameters(0.0, 255.0)
1132 quantization.scale_f32 = 1.0 / (k_w * k_h)
1133 quantization.zero_point = 0
1134 shape = [k_h, k_w, 1, op.ofm.shape[-1]]
1135 weights = np.full(shape, 1)
1136
1137 weight_tens = create_const_tensor(
1138 op.name + "_weights",
1139 shape,
1140 op.ifm.dtype,
1141 weights,
1142 np.uint8,
1143 purpose=TensorPurpose.Weights,
1144 quantization=quantization,
1145 )
James Peet7519d502021-07-19 16:47:58 +01001146 weight_tens.values = weights
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001147 op.type = Op.DepthwiseConv2DBias
1148 op.inputs = []
1149 op.add_input_tensor(ifm)
1150 op.add_input_tensor(weight_tens)
1151 # Add bias tensor, all biases set to 0
1152 op.inputs.append(None)
1153 fixup_bias_tensors(op, arch, nng)
1154 # Add other inputs
1155 op.inputs.extend(other_inputs)
1156 op.rounding_mode = NpuRoundingMode.NATURAL
1157
1158 # Bypass the PAD operator
1159 op.set_input_tensor(pad_op.ifm, 0)
1160 # Adjust the padding attributes of the convolution operator
1161 op.attrs["padding"] = Padding.EXPLICIT
1162 op.attrs["explicit_padding"] = (top, left, bottom, right)
1163 op.set_ifm_ofm_shapes()
1164 return op
1165
1166
1167def convert_pad(op: Operation, arch, nng):
1168 """
1169 Rewrites PAD operator to an average pool that copies the IFM to the OFM
1170 + up to 4 average pool operators that fill the OFM with zeros at the borders.
1171 This is done as fall-back for the PAD operators that remain after replace_pad_by_hw_pad
1172 """
1173 if op.type != Op.Pad or not op.run_on_npu:
1174 return op
1175 top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
1176
1177 ifm = op.ifm
1178 assert ifm is not None
James Ward3e134342021-10-28 10:01:40 +01001179 ifm_shape = op.ifm_shapes[0]
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001180 ofm = op.ofm
1181 assert ofm is not None
1182 ofm.ops = []
1183 ofm_shape = op.ofm_shapes[0]
1184
1185 # Average pool op that copies IFM to the right place inside the OFM
1186 shp0 = Shape4D(0, 0, 0, 0)
1187 shp_top = shp0.with_height(top)
1188 avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
1189 avgpool_op.activation = op.activation
1190 quant = ofm.quantization
1191 pad_value = quant.zero_point
1192 # Add operations that fill the borders of the OFM
1193 if top > 0:
1194 shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
1195 zero_tens = create_const_tensor(
1196 op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1197 )
1198 # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
1199 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1200 create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
1201 if bottom > 0:
1202 shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
1203 zero_tens = create_const_tensor(
1204 op.name + "_bottom",
1205 shape.as_list(),
1206 ofm.dtype,
1207 shape.elements() * [pad_value],
1208 np.uint8,
1209 quantization=quant,
1210 )
1211 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1212 create_avg_pool_for_concat(
1213 op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
1214 )
1215 if left > 0:
1216 shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
1217 zero_tens = create_const_tensor(
1218 op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1219 )
1220 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1221 create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
1222 if right > 0:
1223 shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
1224 zero_tens = create_const_tensor(
1225 op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
1226 )
1227 zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
1228 create_avg_pool_for_concat(
1229 op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
1230 )
1231
1232 op.type = Op.ConcatTFLite
1233 return avgpool_op
1234
1235
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001236def fixup_bias_tensors(op, arch, nng):
1237 if op.type.needs_bias() and op.bias is None:
1238 # Op has no bias, add bias tensor filled with zeros
1239 nr_biases = op.inputs[1].shape[-1]
1240 bias_values = [0] * nr_biases
1241 bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001242 op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
1243
1244 return op
1245
1246
Fredrik Svedbergcc8569f2021-11-01 14:25:29 +01001247def fixup_asymmetric_weights(op, arch, nng):
1248 if op.run_on_npu and (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op()):
1249 if op.ifm.dtype == DataType.int8:
1250 if not np.all(op.weights.quantization.zero_point == 0):
1251 print(f"Warning: {op.type} '{op.name}' has asymmetric weights, zero points have been adjusted.")
1252 op.weights.quantization.zero_point *= 0
1253
1254 return op
1255
1256
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001257def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
1258 if op.type == Op.Mean and op.run_on_npu:
1259 keep_dims = op.attrs.get("keep_dims", False)
1260 inp, axis = op.inputs
1261 shape = inp.shape
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001262 ofm_shape = op.ofm.shape
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001263 dims = len(shape)
Diqing Zhong1ddb2ed2022-03-09 12:23:47 +01001264 dims_ofm = len(ofm_shape)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02001265
1266 # Height and width axes have different index depending on dimensions
1267 if axis.shape == [] or axis.shape[0] == 1: # single axis
1268 axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
1269 if dims in (2, 3):
1270 if axis == 0:
1271 h, w = shape[axis], 1
1272 else:
1273 h, w = 1, shape[axis]
1274 else:
1275 if axis == 1:
1276 h, w = shape[axis], 1
1277 else:
1278 h, w = 1, shape[axis]
1279 else: # multiple axes
1280 axis = sorted(axis.values)
1281 h, w = [shape[i] for i in axis]
1282
1283 # Set necessary depthwise attributes
1284 op.attrs.update(
1285 {
1286 "padding": Padding.VALID,
1287 "stride_h": 1,
1288 "stride_w": 1,
1289 "strides": (1, 1, 1, 1),
1290 "depth_multiplier": 1,
1291 "channel_multiplier": 1,
1292 "dilation_h_factor": 1,
1293 "dilation_w_factor": 1,
1294 "dilation": (1, 1, 1, 1),
1295 }
1296 )
1297 # Change op type
1298 op.type = Op.DepthwiseConv2DBias
1299 # Set IFM/OFM shapes after changing op type
1300 op.set_ifm_ofm_shapes()
1301
1302 weight_scale, bias = 1, None
1303 ofmq, ifmq = op.ofm.quantization, inp.quantization
1304 # Set rounding mode, scaling and zero point based on which reference implementation to match
1305 if len(shape) == 4 and axis == [1, 2] and keep_dims:
1306 if inp.dtype == DataType.uint8:
1307 # This attribute means a different scaling calculation is used in order to match reference
1308 op.low_precision_scaling = True
1309 weight_scale = h * w
1310 # Set zero points to 0 as they will be adjusted for with bias term
1311 foq = ofmq.clone()
1312 foq.zero_point = 0
1313 fiq = ifmq.clone()
1314 fiq.zero_point = 0
1315 op.forced_input_quantization = fiq
1316 bias_term = ofmq.zero_point - int(ifmq.zero_point * ifmq.scale_f32 / ofmq.scale_f32)
1317 # If the bias term is outside uint8 range, we need an Add op to apply it.
1318 if bias_term < 0 or bias_term > 255:
1319 intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
1320 # Bias term has higher bitness (i32) than input/output (u8).
1321 # 16 bits is enough since the bias is added/subtracted from a u8 value,
1322 # the bias can only effectively assume values in the range [-255, 255].
                    intermediate.dtype = DataType.int16
                    intermediate.quantization.zero_point = 0
                    add_op = Operation(Op.Add, op.name + "_bias")
                    add_op.forced_output_quantization = foq
                    add_op.add_input_tensor(intermediate)
                    quant = QuantizationParameters()
                    quant.zero_point = 0
                    bias_term_tens = create_const_tensor(
                        op.name + "_bias",
                        [1, 1, 1, 1],
                        DataType.int16,
                        [bias_term],
                        np.int16,
                        quantization=quant,
                    )
                    add_op.add_input_tensor(bias_term_tens)
                    add_op.set_output_tensor(op.ofm)
                    add_op.set_ifm_ofm_shapes()
                    add_op.activation = op.activation
                    op.activation = None
                    op.set_output_tensor(intermediate)
                    op.set_ifm_ofm_shapes()
                # If not, we can just do it with the OFM zero point.
                else:
                    foq.zero_point = bias_term
                    op.forced_output_quantization = foq
            else:
                assert inp.dtype == DataType.int8
                # Use a depthwise to calculate the sum,
                # followed by a multiplication with 1/N to get the MEAN
                weight_scale = 1
                intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
                intermediate.dtype = DataType.int16
                mul_op = Operation(Op.Mul, op.name + "_mul")
                mul_op.add_input_tensor(intermediate)
                # Create scalar containing 1/N
                quant = QuantizationParameters()
                quant.zero_point = 0
                # The reference rounds negative numbers downwards, e.g. -1.5 is rounded to -2,
                # while rounding mode NATURAL would round this to -1.
                # This can only occur if N is even, and can be emulated by
                # multiplying with a number that is slightly smaller than 1/N.
                # The adjustment must be so small that other roundings are not affected;
                # the calculated value is based on the worst case,
                # which is a sum of 256 * N (the maximum sum that can occur with int8).
                n = int(h * w)
                eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
                quant.scale_f32 = 1 / (n - eps)
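                # For illustration (assumed values): with h * w = 4, eps = 1 / 1280 and
                # scale_f32 ~= 0.2500488, so a sum of -6 maps to ~-1.5003, which NATURAL
                # rounding takes to -2 (matching the reference), whereas an exact 1/4 scale
                # would give -1.5 and round to -1.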
                scalar = create_const_tensor(
                    op.name + "_scalar", [1, 1, 1, 1], DataType.uint8, [1], np.uint8, quantization=quant
                )
                mul_op.add_input_tensor(scalar)
                mul_op.set_output_tensor(op.ofm)
                mul_op.set_ifm_ofm_shapes()
                mul_op.rounding_mode = NpuRoundingMode.NATURAL
                mul_op.activation = op.activation
                op.activation = None
                op.set_output_tensor(intermediate)
                op.set_ifm_ofm_shapes()
        elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
            # Here we can just use a simple AvgPool with truncating rounding,
            # as we're emulating simple integer division.
            op.rounding_mode = NpuRoundingMode.TRUNCATE
            op.type = Op.AvgPool
            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
        else:
            op.rounding_mode = NpuRoundingMode.NATURAL
            weight_scale = 1 / (h * w)
            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
            bias = -ifmq.zero_point * h * w
            fiq = ifmq.clone()
            fiq.zero_point = 0
            op.forced_input_quantization = fiq

        # Change dimensions to 4
        def extend_dims(dim, in_shape):
            if dim < 4:
                in_shape = [1] + in_shape
                if dim == 2:
                    in_shape += [1]
            return in_shape
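        # For illustration: extend_dims(2, [h, w]) -> [1, h, w, 1] and
        # extend_dims(3, [h, w, c]) -> [1, h, w, c]; 4D shapes are returned unchanged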

        if dims < 4 or dims_ofm < 4:
            # Fix the ofm dimension when keep_dims is false
            # e.g. IFM=1xHxWxC axis=2 OFM=1xHxC, the ofm_shape should be 1xHx1xC, not 1x1xHxC
            if isinstance(axis, int) and dims_ofm + 1 == dims:
                ofm_shape.insert(axis, 1)
            elif isinstance(axis, list) and (dims_ofm + len(axis) == dims):
                for i in axis:
                    ofm_shape.insert(i, 1)
            shape = extend_dims(dims, shape)
            dims_ofm = len(ofm_shape)
            ofm_shape = extend_dims(dims_ofm, ofm_shape)
            op.set_ifm_ofm_shapes()

        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            if h > 256 and op.type == Op.AvgPool:
                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})

        # If the AvgPool version is used, we don't need to do anything else
        if op.type == Op.AvgPool:
            return op

        # Make unit weight tensor quantization
        weight_quant = ifmq.clone()
        weight_quant.min = 0
        weight_quant.max = 255
        weight_quant.scale_f32 = weight_scale
        weight_quant.zero_point = 0

        # Set weight shape to [H,W,C,B]
        weight_shape = [h, w, shape[3], shape[0]]

        # Add unit weight tensor
        op.set_input_tensor(
            create_const_tensor(
                "weights",
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
                value_dtype=np.uint8,
                quantization=weight_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)

        # Add None bias tensor
        op.inputs.append(None)
        # Add bias tensor
        if bias:
            bias_shape = [shape[-1]]
            op.set_input_tensor(
                create_const_tensor(
                    "bias",
                    bias_shape,
                    inp.dtype,
                    np.ones(bias_shape) * bias,
                    value_dtype=np.int32,
                    quantization=None,
                ),
                2,
            )

    return op


def optimise_quantize(op: Operation, arch, nng):
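    """Fold a Quantize op with a constant input into a constant tensor calculated at compile time"""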

    if op.type == Op.Quantize and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()
        input_values = ifm.values

        # Guard clause - input not const or no values to quantize
        if ifm.ops[0].type != Op.Const or input_values is None:
            return op

        # Singular val in numpy array, convert to indexable array
        if input_values.ndim == 0:
            input_values = np.array([input_values])

        # Requantize int8 to int8 or int16 to int16
        if ifm.dtype == ofm.dtype == DataType.int8 or ifm.dtype == ofm.dtype == DataType.int16:

            # scale needs to use double precision to match TFLite reference kernel
            effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
            effective_multiplier, effective_shift = quantise_scale(effective_scale)

            requantized_vals = []
            for val in input_values.flatten():
                input_val = val - ifm.quantization.zero_point

                ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
                ofm_val += ofm.quantization.zero_point

                clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
                requantized_vals.append(clamped_ofm_value)

            ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
            ofm.values.shape = input_values.shape

        # Case: Float input - quantize to int
        elif ifm.dtype.type == BaseType.Float:

            quantized_vals = []
            for val in input_values:

                # Derive quantized value
                quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
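                # For illustration (assumed values): with scale_f32 = 0.5 and zero_point = 10,
                # a float value of 2.0 maps to 2.0 / 0.5 + 10 = 14 before clamping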
                clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
                quantized_vals.append(clamped_quantized_val)

            # Pass the statically calculated quant val to output tensor
            ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())

        # Unsupported data type
        else:
            return op

        # Make quantize op const and disconnect from parent node

        # Remove reference of the current quant op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this quantize op to const
        op.type = Op.Const

    return op


def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
    """Static optimisation for SHAPE operator output value known at compile time"""

    # Disconnect SHAPE operator from its parent and transform SHAPE op into constant

    if op.type == Op.Shape and op.run_on_npu:

        ifm, ofm = op.get_ifm_ofm()

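        # The SHAPE output should hold one element per IFM dimension; if not, leave the op untouched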
        if len(ifm.shape) != ofm.shape[0]:
            return op

        # Remove reference of the current shape op from the parent tensor's consumer list
        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]

        # Clear any references to parent node
        op.inputs = []

        # Convert this SHAPE op to const
        op.type = Op.Const

        # Add size calculation to shape output tensors
        ofm.values = np.array(ifm.shape)

    return op


def supported_operator_check(op, arch, nng):
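    """Check TFLite operator support and mark whether the op can run on the NPU"""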
    op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
    return op


def tflite_optimise_graph(nng, arch):
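    """Run the TFLite graph optimisation passes over all subgraphs in the network"""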
    # Compile time static optimisations
    optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            optimisation_list,
            rewrite_unsupported=False,
        )

    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            pre_process_list,
            rewrite_unsupported=False,
        )

    # Handle Concat Ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Handle Split Ops
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_unpack_output, rewrite_stridedslice_output, convert_nop_split_to_identity],
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [rewrite_split_ops],
            [],
            rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [fix_sg_input_output],
            rewrite_unsupported=False,
        )

    # Removal of memory only operators
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_only_ops])
        sg.refresh_after_modification()

    # Rewrite of operators
    op_rewrite_list = [
        set_tensor_equivalence,
        convert_mean_to_depthwise_conv_or_avgpool,
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        optimise_strided_conv,
        convert_hardswish_to_lut,
        rewrite_fully_connected_input,
        convert_batched_fc_shape,
        fixup_conv2d_backprop,
        fixup_relus_with_differing_ifm_ofm_scaling,
        reorder_depthwise_weights,
        fixup_resize,
        fixup_bias_tensors,
        fixup_asymmetric_weights,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
        convert_tanh_sigmoid_to_lut,
        replace_pad_by_hw_pad,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            op_rewrite_list,
            rewrite_unsupported=False,
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [remove_passthrough_tensor],
            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
        )

    # Removal of SplitSliceRead; this needs to be done after the optimisations above,
    # since the ifm/ofm_shapes they set are of importance to this pass
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
        sg.refresh_after_modification()

    return nng