# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of the network graph, using the rewrite_graph module to do the traversal of the graph. These are
# split into two parts optimise_graph_a and optimise_graph_b.
import math

import numpy as np

from . import lut
from . import rewrite_graph
from .data_type import DataType
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import full_shape
from .operation import NpuBlockType
from .operation import Operation
from .softmax import SoftMax
from .tensor import create_const_tensor
from .tensor import create_reshape_tensor
from .tensor import QuantizationParameters
from .tensor import Tensor

passthrough_nodes = set(("Identity",))


def remove_passthrough_tensor(tens, arch):
    if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
        assert len(tens.ops[0].inputs) == 1
        tens = tens.ops[0].inputs[0]
    return tens


def rewrite_concat(tens, arch):
    if len(tens.ops) == 1 and tens.ops[0].is_concat_op():
        concat_op = tens.ops[0]
        if tens != concat_op.outputs[0]:
            return tens  # don't attempt to rewrite the min/max outputs of QuantizedConcat

        # Not supported so leave it and run on CPU
        if not concat_op.run_on_npu:
            return tens

        inputs, axis = concat_op.get_concat_inputs_axis()

        tens.ops = []
        offset = 0
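        # e.g. concatenating [1, 8, 8, 16] and [1, 8, 8, 24] along axis 3 creates two ConcatSliceWrite ops
        # with (concat_start, concat_end) = (0, 16) and (16, 40), both writing into the same [1, 8, 8, 40] tensor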
        for idx, inp in enumerate(inputs):
            new_op = Operation("ConcatSliceWrite", concat_op.name + str(idx))
            new_op.inputs = [inp]
            new_op.outputs = [tens]
            new_op.attrs["concat_axis"] = axis
            new_op.attrs["concat_start"] = offset
            offset += inp.shape[axis]
            new_op.attrs["concat_end"] = offset
            new_op.run_on_npu = True
            tens.ops.append(new_op)
        assert tens.shape[axis] == offset

        # If axis corresponds to the C-dimension, NHCWB16 can only be used in the output if all the
        # concat_start values are a multiple of 16, since only then will the OFM address offsets of all
        # the ConcatSliceWrite operations be 16 byte aligned. For other values of axis the address offsets
        # are all based on c = 0, and those addresses are always 16 byte aligned due to the NHCWB16 format.
        if axis == (len(tens.shape) - 1):
            for op in tens.ops:
                if op.attrs["concat_start"] % 16 != 0:
                    tens.avoid_NHCWB16 = True
                    break

    return tens


def rewrite_split(tens, arch):

    if len(tens.ops) == 1 and tens.ops[0].is_split_op():
        split_op = tens.ops[0]

        # Not supported so leave it and run on CPU
        if not split_op.run_on_npu:
            return tens

        inp, outputs, axis, offset_start, offset_end = split_op.get_split_inputs_axis()

        tens.ops = []
        new_op = Operation("SplitSliceRead", split_op.name)
        new_op.inputs = [inp]

        # For Split the offset cannot be extracted from the tensor so it has to
        # be calculated from the index of the output tensor
        if axis is not None:
            # Get the start and end of the split
            offset_start = [0] * len(tens.shape)
            offset_end = [0] * len(tens.shape)
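            # e.g. for the third output of a Split into slices of size 4 along axis,
            # offset_start[axis] accumulates to 8: the combined size of the two preceding outputs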
            for out in outputs:
                if out == tens:
                    break
                offset_start[axis] += out.shape[axis]

            # If the start offset is not a multiple of 16 in the C-dimension, NHCWB16 needs to be avoided in the input
            if (offset_start[-1] % 16) != 0:
                inp.avoid_NHCWB16 = True

            offset_end[axis] = offset_start[axis] + tens.shape[axis]

        new_op.attrs["split_start"] = offset_start
        new_op.attrs["split_end"] = offset_end
        new_op.run_on_npu = True
        new_op.set_output_tensor(tens)

    return tens


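# Calculates the total padding a dimension needs for SAME-style output sizing (out = ceil(in / stride)),
# e.g. input_size=7, stride=2, filter_size=3 -> out_size=4, needed_input=9, total_padding=2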
def needed_total_padding(input_size, stride, filter_size):
    out_size = (input_size + stride - 1) // stride
    needed_input = (out_size - 1) * stride + filter_size
    total_padding = max(0, needed_input - input_size)
    return total_padding


def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims):
    ypad = needed_total_padding(int(input_dims[1]), int(stride[1]), int(kernel_size[0]))
    xpad = needed_total_padding(int(input_dims[2]), int(stride[2]), int(kernel_size[1]))
    if padding_type == b"SAME":
        left_pad = (xpad + 0) // 2
        right_pad = (xpad + 1) // 2
        top_pad = (ypad + 0) // 2
        bottom_pad = (ypad + 1) // 2
    elif padding_type == b"VALID":
        left_pad = 0
        right_pad = 0
        top_pad = 0
        bottom_pad = 0
    else:
        raise UnsupportedFeatureError("Unknown padding {}".format(str(padding_type)))
    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dims, upscaling_factor):
    kernel_height, kernel_width = kernel_size[0], kernel_size[1]
    if padding_type == b"SAME":
        ypad = needed_total_padding(int(input_dims[1]) * upscaling_factor, int(stride[1]), int(kernel_height))
        xpad = needed_total_padding(int(input_dims[2]) * upscaling_factor, int(stride[2]), int(kernel_width))

        right_pad = ((xpad + 1) // upscaling_factor) - 1
        bottom_pad = ((ypad + 1) // upscaling_factor) - 1
        left_pad = max(kernel_width - 1 - right_pad, 0)
        top_pad = max(kernel_height - 1 - bottom_pad, 0)

    elif padding_type == b"VALID":
        right_pad = max(kernel_width - 2, 0)
        bottom_pad = max(kernel_height - 2, 0)
        left_pad = kernel_width - 1
        top_pad = kernel_height - 1
    else:
        assert 0, "Unknown padding"

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = padding
    return padding, skirt


def fixup_conv2d_backprop(op, arch):
    if op.type == "Conv2DBackpropInput":
        # flip the inputs
        op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
        op.type = "Conv2DBackpropInputSwitchedBias"

        # Update strides
        op.attrs.update({"stride_w": 1, "stride_h": 1, "strides": (1, 1, 1, 1)})

    return op


# Convert the op to an elementwise add
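# (bilinear upscaling of a 1x1 IFM just broadcasts its single value, so adding the IFM to a
# zero-filled tensor of the OFM shape produces the same result as an elementwise op on the NPU)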
def convert_resizebilinear_1x1_to_add(op):
    op.type = "AddAct"
    op.name = op.name + "_add"
    op.attrs.update({"npu_block_type": NpuBlockType.ElementWise})
    op.attrs["resizebilinear"] = True
    # Create an input tensor filled with zeros
    shape = op.outputs[0].shape
    tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
    tens.values = np.zeros(shape)
    tens.quant_values = np.zeros(shape, np.uint8)
    tens.quantization = QuantizationParameters(0.0, 255.0)
    tens.quantization.scale_f32 = 1.0
    tens.quantization.zero_point = 0
    tens.consumer_list = [op]
    tens_op = op.inputs[1].ops[0]
    tens_op.set_output_tensor(tens)
    # Set the add inputs
    op.inputs[1] = op.inputs[0]
    op.inputs[0] = tens

    return op


def fixup_resizebilinear(op, arch):
    if op.type == "ResizeBilinear":
        if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
            convert_resizebilinear_1x1_to_add(op)
        elif op.inputs[0].shape == op.outputs[0].shape:
            # Bypass nop resizebilinear
            op.inputs = op.inputs[:1]
            op.type = "Identity"

    return op


def fixup_fully_connected_input(op, arch):
    if op.type == "FullyConnectedAct":
        inp = op.inputs[0]
        weights = op.inputs[1]

        n_in_elems = weights.shape[-2]
        elms = inp.elements()
        batch_size = elms // n_in_elems
        assert batch_size * n_in_elems == elms

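        # e.g. a [1, 1, 1, 64] IFM with weights of shape [..., 64, 10] is reshaped to the expected 2D shape [1, 64]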
        desired_shape = [batch_size, n_in_elems]
        if inp.shape != desired_shape:
            # mismatch, insert a reshape to fix this.
            op.inputs[0] = create_reshape_tensor(inp, desired_shape)

    return op


def fixup_pack_input(op, arch):
    if op.type == "Pack":
        # Pack is also referred to as Stack
        # Requires the rewrite_concat function to be called on the op afterwards
        axis = int(op.attrs["axis"])
        desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
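        # e.g. packing three [2, 4] inputs along axis=1: each input is reshaped to [2, 1, 4], and the
        # later rewrite_concat pass then joins the reshaped inputs into the [2, 3, 4] output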

        # Construct 1 shape tensor to be used by all inserted reshape ops
        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, desired_shape)

        for idx, inp in enumerate(op.inputs):
            reshape_out = inp.clone("_reshaped")
            reshape_out.set_all_shapes(desired_shape)

            reshape_op = Operation("Reshape", "{}{}_reshape".format(op.name, idx))
            reshape_op.attrs["new_shape"] = desired_shape
            reshape_op.inputs = [inp, new_shape_tens]
            reshape_op.set_output_tensor(reshape_out)

            op.inputs[idx] = reshape_out

        op.type = "PackReshaped"

    return op


def fixup_unpack_output(tens, arch):
    op = tens.ops[0]
    if op.type in set(("Unpack", "StridedSlice")):
        # Unpack is also referred to as Unstack
        # Requires the rewrite_split function to be called on the op afterwards

        reshape_input_shape = tens.shape
        if op.type == "StridedSlice":
            new_axis_mask = op.attrs["new_axis_mask"]
            shrink_axis_mask = op.attrs["shrink_axis_mask"]
            ellipsis_mask = op.attrs["ellipsis_mask"]

            if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0:
                # Not supported, will be put on CPU
                return tens
            if shrink_axis_mask == 0 and new_axis_mask == 0:
                # Equal Rank StridedSlice, no need to insert reshape
                return tens
            elif shrink_axis_mask != 0:
                n = 0
                axis = 0
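                # shrink_axis_mask &= shrink_axis_mask - 1 clears the lowest set bit; log2 of the cleared
                # bit (prev_mask - shrink_axis_mask) is the axis of a shrunk dimension, where a size-1 dim
                # is reinserted, e.g. shrink_axis_mask == 0b101 reinserts dims at axes 0 and 2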
                while shrink_axis_mask:
                    prev_mask = shrink_axis_mask
                    n += 1
                    shrink_axis_mask &= shrink_axis_mask - 1
                    axis = int(math.log2(prev_mask - shrink_axis_mask))
                    reshape_input_shape = reshape_input_shape[:axis] + [1] + reshape_input_shape[axis:]

                assert len(tens.shape) == (len(op.inputs[0].shape) - n)
                op.attrs["shrink_axis_mask"] = 0

            elif new_axis_mask != 0:
                n = 0
                axis = 0
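                # Same bit trick as above, but here the size-1 dims that new_axis_mask added are removed;
                # the mask is shifted right after each removal since the remaining axes shift down by one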
                while new_axis_mask:
                    prev_mask = new_axis_mask
                    n += 1
                    new_axis_mask &= new_axis_mask - 1
                    axis = int(math.log2(prev_mask - new_axis_mask))
                    reshape_input_shape = reshape_input_shape[:axis] + reshape_input_shape[(axis + 1) :]
                    new_axis_mask >>= 1

                assert len(tens.shape) == (len(op.inputs[0].shape) + n)
                op.attrs["new_axis_mask"] = 0
        else:
            axis = int(op.attrs["axis"])
            op.type = "UnpackReshaped"
            reshape_input_shape = tens.shape[:axis] + [1] + tens.shape[axis:]

        # Construct 1 shape tensor to be used by all inserted reshape ops
        new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)

        for idx, out_tens in enumerate(op.outputs):
            reshape_in = out_tens.clone("_reshaped")
            reshape_in.set_all_shapes(reshape_input_shape)
            reshape_in.ops = [op]

            reshape_op = Operation("Reshape", "{}{}_reshape".format(op.name, idx))
            reshape_op.attrs["new_shape"] = reshape_input_shape
            reshape_op.inputs = [reshape_in, new_shape_tens]
            reshape_op.set_output_tensor(out_tens)

            op.outputs[idx] = reshape_in

    return tens


def add_padding_fields(op, arch):
    if "padding" in op.attrs:
        if "Conv" in op.type:
            kernel_size = op.inputs[1].shape[:2]
            input_shape = op.inputs[0].shape
        elif "Pool" in op.type or op.type in ("ResizeBilinear", "ReduceSum"):
            kernel_size = op.attrs["ksize"][1:3]
            input_shape = op.inputs[0].shape
        elif op.type == "ExtractImagePatches":
            kernel_size = op.attrs["ksizes"][1:3]
            input_shape = op.inputs[0].shape
        else:
            raise UnsupportedFeatureError("Unknown operation that uses padding: {}".format(op.type))

        if op.type == "Conv2DBackpropInputSwitchedBias":
            upscaling_factor = op.outputs[0].shape[1] // input_shape[1]
            padding, skirt = calc_upscaled_padding_and_skirt(
                op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
            )
        else:
            dilation_h, dilation_w = op.get_dilation_h_w()
            dilated_kernel_size = [dilation_h * (kernel_size[0] - 1) + 1, dilation_w * (kernel_size[1] - 1) + 1]
            padding, skirt = calc_padding_and_skirt(
                op.attrs["padding"], dilated_kernel_size, op.attrs["strides"], input_shape
            )

        op.attrs["explicit_padding"] = padding
        op.attrs["skirt"] = skirt

    return op


conv_op = set(("Conv2D", "QuantizedConv2D", "Conv2DBackpropInputSwitchedBias", "Conv2DBiasAct"))
fc_op = set(
    (
        "MatMul",
        "QuantizedMatMul",
        "BlockLSTM",
        "RnnAct",
        "UnidirectionalSequenceRnnAct",
        "BidirectionalSequenceRnnAct",
        "LstmAct",
        "UnidirectionalSequenceLstmAct",
        "BidirectionalSequenceLstmAct",
        "FullyConnectedAct",
    )
)
depthwise_op = set(("DepthwiseConv2dNative", "DepthwiseConv2dBiasAct",))
pool_op = set(
    ("AvgPool", "MaxPool", "QuantizedAvgPool", "QuantizedMaxPool", "AvgPoolAct", "MaxPoolAct", "ResizeBilinear")
)
reduce_sum_ops = set(("ReduceSum",))
elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum", "LeakyRelu", "Abs", "CLZ", "SHL", "SHR"))
binary_elementwise_op = set(("AddAct", "MulAct", "SubAct", "Maximum", "Minimum"))
activation_ops = set(("Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh"))
memory_only_ops = set(("Reshape",))


# Check if the op can be reordered
def get_prepend_op(op):
    inp = op.inputs[0]
    # The op should be reordered between prev_op and prep_op
    prev_op = inp.ops[-1]
    prep_op = None
    while prev_op.type in memory_only_ops and len(prev_op.outputs) == 1 and len(prev_op.outputs[0].consumers()) == 1:
        prep_op = prev_op
        inp = prev_op.inputs[0]
        prev_op = inp.ops[-1]
    if prev_op is not None and len(prev_op.outputs) == 1 and len(prev_op.outputs[0].consumers()) == 1:
        return prep_op

    return None


def mark_npu_block_type(op, arch):
    npu_block_type = NpuBlockType.Default
    if op.type in conv_op:
        npu_block_type = NpuBlockType.ConvolutionMxN
    elif op.type in fc_op:
        npu_block_type = NpuBlockType.VectorProduct
    elif op.type in depthwise_op:
        npu_block_type = NpuBlockType.ConvolutionDepthWise
    elif op.type in pool_op:
        npu_block_type = NpuBlockType.Pooling
    elif op.type in elementwise_op:
        npu_block_type = NpuBlockType.ElementWise
    elif op.type in reduce_sum_ops:
        npu_block_type = NpuBlockType.ReduceSum

    op.attrs["npu_block_type"] = npu_block_type
    return op


def convert_depthwise_to_conv(op, arch):
    # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and
    # the ofm depth equals the depth multiplier.
    # If those conditions are true, then we can perform a simple
    # switch of the operator type (and weight order)
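    # e.g. with ifm depth 1 and depth_multiplier 8, the op computes 8 output channels from a single
    # input channel, exactly as a Conv2D would; only the last two weight axes need to be swapped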

    if ("DepthwiseConv2d" in op.type) and (op.attrs["depth_multiplier"] != 1):
        ifm_tensor = op.inputs[0]
        weight_tensor = op.inputs[1]
        ofm_tensor = op.outputs[0]
        if (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"]):
            # Change op type to Conv2d
            op.type = op.type.replace("DepthwiseConv2d", "Conv2D")
            del op.attrs["channel_multiplier"]
            del op.attrs["depth_multiplier"]

            weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
        else:
            raise UnsupportedFeatureError(
                "Unsupported DepthwiseConv2d with depth_multiplier = {}, ifm channels = {}, ofm channels = {}".format(
                    op.attrs["depth_multiplier"], ifm_tensor.shape[3], ofm_tensor.shape[3]
                )
            )
    return op


def reorder_depthwise_weights(op, arch):
    if "DepthwiseConv2d" in op.type:
        weight_tensor = op.inputs[1]
        weight_tensor.quant_values = np.transpose(weight_tensor.quant_values, (0, 1, 3, 2))
        weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
        weight_tensor.weight_transpose_depthwise = True

    return op


def convert_conv_to_fc(op, arch):
    # Conv 1x1 can be equivalent to Fully Connected.
    # By representing certain convs as fully connected layers, Vela can better determine whether or not to use
    # caching/double buffering for the weights.
    # (Weights don't need to be reloaded for convs when IFM H and W are 1)
    if op.type == "Conv2DBiasAct":
        _, h, w, _ = op.inputs[0].shape
        kh, kw, _, _ = op.inputs[1].shape
        if h == 1 and w == 1 and kh == 1 and kw == 1:
            # Overwrite this op as a Fully Connected Op
            op.name += "_fc"
            op.type = "FullyConnectedAct"
            faf = op.attrs.get("fused_activation_function", None)
            op.attrs = {
                "fused_activation_function": faf,
                "weights_format": 0,
                "npu_block_type": NpuBlockType.VectorProduct,
            }
            # Reshape Weights to be 2D. HWIO becomes just IO (as H and W are 1, they can just be dropped)
            weight_tensor = op.inputs[1]
            weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1))
            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
            # The output from a fully connected is expected to be 2D, so we need to add a reshape layer to
            # convert it back to 4D afterwards, as the next layer is expecting that shape
            orig_ofm_tensor = op.outputs[0]
            # Reshape this op's output to be 2D: {(N*H*W), C} (We know N, H and W are all 1, so this becomes {1, C})
            fc_ofm_tensor = orig_ofm_tensor.clone("_fc")
            fc_ofm_tensor.set_all_shapes([1, fc_ofm_tensor.shape[-1]])
            fc_ofm_tensor.ops = [op]
            # Add a reshape after the new OFM to convert it back to the original 4D shape
            reshape_name = op.name + "_reshape"
            new_shape_tens = create_const_tensor(reshape_name + "_shape", [1], DataType.int32, orig_ofm_tensor.shape)
            reshape_op = Operation("Reshape", reshape_name)
            reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
            reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
            reshape_op.set_output_tensor(orig_ofm_tensor)
            # Replace this op's OFM to point to the 2D tensor
            op.outputs[0] = fc_ofm_tensor
    return op


# Reorder activation op if it's after the memory only operations
def fixup_act_reorder(op, arch):
    if op.type in activation_ops:
        prep_op = get_prepend_op(op)
        if prep_op is not None:
            act_op = op.clone("_reordered")
            act_op.inputs = [prep_op.inputs[0]]
            act_op_out = act_op.inputs[0].clone("_acted")
            act_op_out.quantization = op.outputs[0].quantization.clone()
            act_op.set_output_tensor(act_op_out)
            prep_op.inputs[0] = act_op_out
            prep_op.outputs[0].quantization = act_op_out.quantization.clone()

            # Mark the op so that it will be removed as passthrough later on
            op.type = "Identity"
    return op


def fixup_elementwise_with_scalars(op, arch):
    if op.type in binary_elementwise_op:
        ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm()
        if ifm2_tensor.shape != [] and ifm_tensor.shape != []:
            diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape)
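            # e.g. adding a [16] tensor to a [1, 8, 8, 16] tensor: the lower-rank shape is
            # left-padded with ones to [1, 1, 1, 16] so that both operands have equal rank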
            if diff > 0:
                ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1)
            elif diff < 0:
                ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1)
        elif ifm_tensor.shape == [] and ifm_tensor.quant_values is None:
            # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm_tensor.shape = len(ifm2_tensor.shape) * [1]
            ifm_tensor.storage_shape = ifm_tensor.shape
        elif ifm2_tensor.shape == [] and ifm2_tensor.quant_values is None:
            # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1
            ifm2_tensor.shape = len(ifm_tensor.shape) * [1]
            ifm2_tensor.storage_shape = ifm2_tensor.shape
    return op


# Set input/output tensor equivalence to the same id for memory operations
def set_tensor_equivalence(op, arch):
    if op.type == "Reshape":
        eid = op.outputs[0].equivalence_id
        for inp in op.inputs:
            inp.equivalence_id = eid
    return op


def convert_softmax(op, arch):
    if op.type == "Softmax" and op.run_on_npu:
        softmax = SoftMax(op)
        op = softmax.get_graph()
    return op


def convert_mul_max_to_abs_or_lrelu(op, arch):
    r"""Whenever there is a subgraph with this topology:

       Input    X   For X = -1 or X > 0
       |   \   /    This subgraph can be replaced with either
       |    Mul     an Abs (if X = -1) or a LeakyReLU (if X > 0)
       |   /
       Max
    """

    if op.type == "Maximum":
        # finds the Mul input(s) to the Max
        muls = [i for i in op.inputs if i.ops[0].type == "MulAct"]
        if len(muls) == 1:
            mul = muls[0].ops[0]
        elif len(muls) == 2:
            # In the case both inputs are Muls, find the one with the same input as the Max
            mul = [m for m in muls if len(set(op.inputs + m.ops[0].inputs)) == 1][0].ops[0]
        else:
            # No Mul inputs
            return op

        # make sure the Mul doesn't have any other consumers
        if len(mul.outputs[0].consumers()) != 1:
            return op
        # make sure the Mul doesn't have a faf
        if mul.attrs["fused_activation_function"]:
            return op
        ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
        if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
            return op
        if not ifm.is_scaling_equal(ofm):
            # rewrite to LeakyRelu currently only makes sense if the quantization is identical
            return op

        # finds the branched input that goes to both the Max and the Mul
        shared = set(op.inputs) & set(mul.inputs)
        if len(shared) == 1:
            shared_in = shared.pop()
            # find the constant scalar input to the Mul
            const_tens = (set(mul.inputs) - {shared_in}).pop()
            # check that it is a scalar
            if const_tens.shape != []:
                return op
            const = const_tens.ops[0]
            # check that it is a constant
            if const.type != "Const":
                return op
            # Remove the Mul from the shared input's consumers
            shared_in.consumer_list.remove(mul)
        else:
            return op

        val = const.outputs[0].values
        if val >= 0:
            new_op = "LeakyRelu"
            op.attrs["alpha"] = val
        elif val == -1:
            new_op = "Abs"
        else:
            return op

        op.type = op.type.replace("Maximum", new_op)
        op.name = op.name.replace("Maximum", new_op)
        op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op)
        op.inputs = [shared_in]
    return op


def convert_lrelu_to_mul_max(op, arch):
    # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
    # (the opposite of convert_mul_max_to_abs_or_lrelu)
    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()

    # Add multiplication with alpha
    mul_alpha = Operation("MulAct", op.name + "_mul_alpha")
    mul_alpha.add_input_tensor(ifm)
    # Create const tensor containing alpha as scalar
    alpha = op.attrs["alpha"]
    quantization = ifm.quantization.clone()
    quantization.min = 0
    quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
    quantization.scale_f32 = alpha
    quantization.zero_point = 0
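    # the quantized value 1 with scale alpha and zero point 0 dequantizes to alpha itself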
    alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [], ifm.dtype, [1], np.int8, quantization=quantization)
    mul_alpha.add_input_tensor(alpha_tens)
    fm_alpha = ofm.clone(op.name + "_alpha")
    mul_alpha.set_output_tensor(fm_alpha)

    if ifm.is_scaling_equal(ofm):
        # No identity multiplication is needed
        fm_id = ifm
    else:
        # Add multiplication with identity
        mul_identity = Operation("MulAct", op.name + "_mul_identity")
        mul_identity.add_input_tensor(ifm)
        # Create const tensor containing identity as scalar
        quantization = ifm.quantization.clone()
        quantization.min = 0
        quantization.max = quantization.quant_max - quantization.quant_min
        quantization.scale_f32 = 1
        quantization.zero_point = 0
        identity_tens = create_const_tensor(
            op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
        )
        mul_identity.add_input_tensor(identity_tens)
        fm_id = ofm.clone(op.name + "_id")
        mul_identity.set_output_tensor(fm_id)

    # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
    op.type = "Maximum"
    op.name = op.name.replace("LeakyRelu", "Maximum")
    op.inputs = []
    ifm.consumer_list.remove(op)
    op.add_input_tensor(fm_alpha)
    op.add_input_tensor(fm_id)
    return op


def convert_lrelu_to_lut(op, arch):
    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
    # Rewrite LeakyRelu by Add with scalar 0 + LUT activation
    op.type = "AddAct"
    op.name = op.name + "_add"
    op.attrs.update({"npu_block_type": NpuBlockType.ElementWise})
    # Mark as no-op to enable potential fusing optimizations
    op.attrs["is_nop"] = True
    # Create an input tensor containing scalar zero
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = 1.0
    quantization.zero_point = 0
    tens = create_const_tensor(op.inputs[0].name + "_add", [], ifm.dtype, [0], np.uint8, quantization=quantization)
    op.add_input_tensor(tens)
    alpha = op.attrs["alpha"]
    zp = ofm.quantization.zero_point
    # Generate the LUT
    if ifm.dtype.size_in_bytes() == 1:
        dtype = DataType.int8
        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
        values = [int(x) if x >= zp else int(round(zp - alpha * (zp - x))) for x in ix]
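        # e.g. alpha=0.1, zp=0: input -100 maps to round(0 - 0.1 * (0 - -100)) = -10, inputs >= zp pass through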
    else:
        # int16
        dtype = DataType.int32
        values = []
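        # 512 entries, each covering a 128-wide segment of the int16 input range and packing a
        # 16-bit slope across the segment (upper half) and a 16-bit base value (lower half)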
        for ix in range(512):
            x = (ix - 256) * 128
            if x >= zp:
                base = x
                slope = 128
            else:
                base = int(round(zp - alpha * (zp - x)))
                next_base = int(round(zp - alpha * (zp - (x + 127))))
                slope = int(round(128 * (next_base - base) / 127))
            value = ((slope << 16) & 0xFFFF0000) + (base & 0xFFFF)
            values.append(value)
    lut_tensor = lut.create_lut_tensor(op.name + "_lut", values, dtype)
    op.set_activation_lut(lut_tensor)
    return op


def convert_lrelu(op, arch):
    # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
    if op.type != "LeakyRelu":
        return op
    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
    use_lut = (ifm.is_scaling_equal(ofm)) and (ifm.dtype == ofm.dtype) and ifm.dtype in (DataType.uint8, DataType.int8)
    if use_lut:
        return convert_lrelu_to_lut(op, arch)
    return convert_lrelu_to_mul_max(op, arch)


def fuse_activation_function_with_prev(op, arch):
    # if op is a no-op: attempts to move the activation function to the preceding op
    if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None:
        return op
    ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
    # finds the input(s) to the operation
    prev_op = ifm.ops[0]
    # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
    fuse = (
        prev_op.run_on_npu
        and prev_op.attrs["npu_block_type"] != NpuBlockType.Default
        and len(ifm.ops) == 1
        and len(prev_op.outputs[0].consumers()) == 1
        and prev_op.attrs.get("fused_activation_function", None) is None
        and ifm.is_scaling_equal(ofm)
    )
    if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
        # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
        # LUT currently only works correctly for elementwise ops
        fuse = False
    if fuse and op.activation_lut is not None:
        # Check if LUT can be used with prev_op
        prev_ifm, prev_ifm2, _, _ = prev_op.get_ifm_ifm2_weights_ofm()
        fuse = prev_ifm is not None and prev_ifm.quantization is not None and prev_ifm.is_scaling_equal(ifm)
        if prev_ifm2 is not None:
            fuse = fuse and prev_ifm2.quantization is not None and prev_ifm2.is_scaling_equal(ifm)
    if not fuse:
        return op
    # Move the fused activation function + corresponding info to prev_op
    for attr in ("fused_activation_function", "alpha"):
        if attr in op.attrs:
            prev_op.attrs[attr] = op.attrs[attr]
    if op.activation_lut is not None:
        prev_op.set_activation_lut(op.activation_lut)
    # Bypass op
    prev_op.set_output_tensor(op.outputs[0])
    return op


def add_attrs_to_resizebilinear(op, arch):
    if op.type == "ResizeBilinear" and op.run_on_npu:
        input_tensor = op.inputs[0]
        upscaled_shape = [input_tensor.shape[1] * 2, input_tensor.shape[2] * 2]
        out_shape = op.outputs[0].shape[1:3]
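        # e.g. input [1, 4, 4, 8]: out_shape [8, 8] means x2 upscale (SAME), [7, 7] with align_corners means VALID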
        if not op.attrs["align_corners"] and out_shape == upscaled_shape:
            # this means the output is supposed to be a x2 upscale,
            # so we need to do SAME padding
            op.attrs["padding"] = b"SAME"
        elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
            # here we can just run the avg pool without padding and
            # produce a (M * 2 - 1, N * 2 - 1) sized output
            op.attrs["padding"] = b"VALID"
        else:
            return op
        input_tensor.resampling_mode = resampling_mode.NEAREST
        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
    return op


def add_bias_tensor(op, arch):
    if ("Conv2d" in op.type or op.type.startswith("FullyConnected")) and not op.inputs[-1]:
        # Add bias/scale tensor filled with zeros
        weight_shape = op.inputs[1].shape
        weight_sets = weight_shape[-1]
        bias_values = [0] * weight_sets
        scale_tens = create_const_tensor(op.name + "_bias", [weight_sets], DataType.int32, bias_values)
        op.set_input_tensor(scale_tens, -1)

    return op


def supported_operator_check(op, arch):
    op.run_on_npu = arch.supported_operators.is_operator_supported(op)
    return op


def optimise_graph_a(nng, arch, verbose_graph=False):
    if verbose_graph:
        nng.print_graph()

    op_rewrite_list = [
        # mark block type and check if the operations are supported
        mark_npu_block_type,
        set_tensor_equivalence,
        supported_operator_check,
        # then do any rewrites of supported operators
        convert_depthwise_to_conv,
        convert_conv_to_fc,
        convert_softmax,
        fixup_fully_connected_input,
        fixup_pack_input,
        fixup_conv2d_backprop,
        fixup_act_reorder,
        add_attrs_to_resizebilinear,
        add_padding_fields,
        mark_npu_block_type,
        fixup_elementwise_with_scalars,
        reorder_depthwise_weights,
        fixup_resizebilinear,
        add_bias_tensor,
        convert_mul_max_to_abs_or_lrelu,
        convert_lrelu,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        # rewrite graph pass
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            sg, arch, [fixup_unpack_output], op_rewrite_list, rewrite_unsupported=False
        )

    for idx, sg in enumerate(nng.subgraphs):
        # remove passthrough tensors and attempt further optimizations
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev]
        )

    if verbose_graph:
        nng.print_graph()
    return nng


def optimise_graph_b(nng, arch, verbose_graph=False):
    if verbose_graph:
        nng.print_graph()

    for idx, sg in enumerate(nng.subgraphs):
        # combined rewrite graph pass
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(sg, arch, [rewrite_concat, rewrite_split], [])

    if verbose_graph:
        nng.print_graph()
    return nng