# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Early optimisation of the TOSA-based network graph, using the rewrite_graph module to traverse the graph.
import numpy as np

from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import convert_to_lut
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import Op
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor
from .tensor import create_equivalence_id


def replace_rescale_with_avg_pool(rescale_op):
    assert rescale_op.type == Op.Rescale

    avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")
    rescale_op_clone = rescale_op.clone()
    op = rescale_op
    op.attrs = avgpool_op.attrs.copy()
    op.type = Op.AvgPool
    DebugDatabase.add_optimised(rescale_op_clone, op)

    return op


def calc_skirt(kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))

    top, left, bottom, right = explicit_padding
    top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
    left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "explicit_padding" in op.attrs:
            input_shape = op.ifm_shapes[0]

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                # TODO not yet supported; will need separate handling
                assert False
            else:
                padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


# Counts the number of leading zero bits in a, treated as an int32
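# For example, count_leading_zeros(1) == 31 and count_leading_zeros(0) == 32,
# so 32 - count_leading_zeros(x - 1) gives ceil(log2(x)) for x >= 1.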
def count_leading_zeros(a):
    lz = int(32)
    if a != 0:
        mask = 1 << (32 - 1)
        lz = 0
        while (mask & a) == 0:
            mask = mask >> 1
            lz = lz + 1
    return lz


def calc_scaling_avgpool(op, arch, nng):
    if op.type == Op.AvgPool:
        top, left, _, _ = op.attrs["explicit_padding"]
        # TODO Only supported when global scaling can be used,
        # that is, when there is no padding
        assert top == 0 and left == 0
        assert op.explicit_scaling is None
        multiplier = []
        shift = []

        kernel_wh = op.kernel.elements_wh()
        k = 32 - count_leading_zeros(kernel_wh - 1)
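        # The multiplier/shift pair approximates division by the kernel area:
        # k == ceil(log2(kernel_wh)) and multiplier ~= 2**(30 + k) / kernel_wh, so that
        # (acc * multiplier) >> (30 + k) ~= acc / kernel_wh.
        # For example, a 3x3 kernel gives kernel_wh == 9, k == 4 and shift == 34.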
        numerator = np.int64(((1 << 30) + 1) << k)
        multiplier.append(numerator // kernel_wh)
        shift.append(30 + k)

        op.rounding_mode = NpuRoundingMode.NATURAL
        op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
    return op


def remove_const_transpose(op, arch, nng):
    if op.type == Op.Transpose:
        removed = False
        if len(op.ifm.ops) == 1:
            prev_op = op.ifm.ops[0]
            if prev_op.type == Op.Const:
                # Transpose the tensor and its data, then remove the Transpose op
                # TODO move to Tensor?
                reorder = op.attrs["perms"]
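                # e.g. perms == [0, 3, 1, 2] permutes an NHWC-shaped constant into NCHW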
                shape = op.ifm.shape.copy()
                tens = op.ifm

                tens.shape = [shape[idx] for idx in reorder]
                tens.bandwidth_shape = tens.shape
                tens.storage_shape = tens.shape

                if tens.values is not None:
                    tens.values = tens.values.transpose(reorder)

                op.ofm.values = tens.values
                # Bypass the Transpose op
                prev_op.set_output_tensor(op.ofm)
                DebugDatabase.add_optimised(op, prev_op)
                removed = True

        if not removed:
            print("Warning: Cannot remove Transpose, and standalone Transpose is not supported")
            assert False

    return op


# TODO can we change to add for both TFLite and TOSA?
def insert_add_copy_op_after_tens(tens):
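    # Inserts an elementwise add with a zero scalar after tens; numerically this is a copy,
    # but it lets the original tensor persist while its consumers read from the copy.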
    tens_cons_list_copy = tens.consumer_list.copy()
    copy_tens = tens.clone()

    name = tens.name + "_add"
    ifm2 = create_const_tensor(
        name + "_zero_scalar",
        [1],
        copy_tens.dtype,
        [0],
        copy_tens.dtype.as_numpy_type(),
        quantization=copy_tens.quantization,
    )
    copy_op = create_add_nop(name)
    copy_op.add_input_tensor(tens)
    copy_op.add_input_tensor(ifm2)
    copy_op.set_output_tensor(copy_tens)
    copy_op.set_ifm_ofm_shapes()
    copy_op.run_on_npu = True

    # Set copy_ifm consumers
    for tens_cons in tens_cons_list_copy:
        if tens_cons is not None:
            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
                if cons_inp == tens:
                    tens_cons.set_input_tensor(copy_tens, ifm_idx)

    DebugDatabase.add_optimised(tens.ops[0], copy_op)


def fix_sg_input_output_tosa(op, arch, nng):
    if not op.run_on_npu or op.type != Op.Reshape:
        return op

    # When a Reshape operator is removed, its tensors are removed as well.
    # But for that to be possible they cannot be outputs of the subgraph;
    # this needs to be fixed prior to the removal.
    # The solution is to add a copy op, to maintain the original tensor.
    # This also applies when the Reshape ifm/ofm is produced or consumed by the CPU.

    # Check if the operator's ifm/ofm are subgraph ifm/ofm
    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
    # Check if the ifm is produced or the ofm is consumed by the CPU
    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)

    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
        # Both ifm and ofm need to persist, but only the ifm needs a copy in order to remove the Reshape
        insert_add_copy_op_after_tens(op.ifm)

    return op


def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an add op for the given concat op/input feature map"""
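    # The second input is a zero scalar, so the add is numerically a copy of ifm that
    # is written into the concat ofm at write_offset (a ConcatSliceWrite).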
    ofm = concat_op.ofm
    ifm2 = create_const_tensor(
        name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
    )
    add_op = create_add_nop(name)

    add_op.inputs = [ifm, ifm2]
    add_op.outputs = [ofm]
    add_op.write_offset = write_offset
    add_op.write_shape = ifm_shape
    ofm.ops.append(add_op)
    DebugDatabase.add_optimised(concat_op, add_op)
    add_op.ifm_shapes.append(ifm_shape)
    add_op.ifm_shapes.append(Shape4D(ifm2.shape))
    add_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    add_op.memory_function = Op.ConcatSliceWrite
    return add_op


# TODO Could be further optimised by checking the type of the consumer,
# rather than just mimicking the TFLite behaviour depending on type.
# TOSA bool_t not considered yet
def remove_splitsliceread(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if a copy op needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type != Op.Reshape
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
            and op.ofm.dtype in (DataType.uint8, DataType.int8, DataType.int16)
        ):
            # SplitSliceRead can be performed by the tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            name = op.name + "_add"
            ofm = op.ofm
            ifm2 = create_const_tensor(
                name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
            )
            add_op = create_add_nop(name)
            add_op.inputs = [op.ifm, ifm2]
            add_op.outputs = [ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(add_op)
            add_op.ifm_shapes.append(op.ifm_shapes[0])
            add_op.ifm_shapes.append(Shape4D(ifm2.shape))
            add_op.ofm_shapes.append(op.ofm_shapes[0])
            add_op.read_offsets[0] = op.read_offsets[0]
            add_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, add_op)


def rewrite_concat_ops(op, arch):
    if not op.run_on_npu or not op.type == Op.Concat:
        return

    axis_4D = 0
    ofm = op.ofm
    ofm.ops = []
    offset = 0

    inputs = op.inputs
    axis = op.attrs["axis"]

    for idx, inp in enumerate(inputs):
        op.ifm_shapes[idx] = Shape4D(inp.shape)
        if axis >= 0:
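            # Shape4D pads the input shape with leading 1s up to rank 4, so a
            # non-negative axis is shifted accordingly; e.g. axis 1 of a rank-3
            # input becomes axis 2 of the 4D shape.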
            axis_4D = axis + (4 - len(inp.shape))
        else:
            axis_4D = axis
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
        offset = concat_end
    assert ofm.shape[axis] == offset

    return op


def remove_reshapes(op, arch):
    if op.run_on_npu and op.type == Op.Reshape:
        bypass_memory_only_ops(op)


def rewrite_activation(op, arch, nng):
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp

    return op


def rewrite_rescale(op, arch, nng):
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that the input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of tensor's producer/consumer differs unexpectedly")
            assert False
        ifm.quantization.zero_point = input_zp
        ofm.quantization.zero_point = output_zp
        for s, m in zip(shift, multiplier):
            # TODO these are the TOSA limitations
            assert m >= 0
            assert 2 <= s <= 62
            # TODO these are the HW limitations
            assert 0 <= s < (1 << 6)
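        # Per TOSA RESCALE semantics, each value is scaled by roughly
        # (value * multiplier) >> shift (with rounding), per channel when per_channel is set.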
        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)

        if double_round and scale32:
            rounding_mode = NpuRoundingMode.TFL
        else:
            rounding_mode = NpuRoundingMode.NATURAL

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            assert len(multiplier) == len(shift) == len(prev_op.bias.values)

            if ifm.dtype == DataType.int32 and per_channel:
                prev_op.explicit_scaling = explicit_scaling
                prev_op.rounding_mode = rounding_mode

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print("Warning, unsupported fusing of TOSA Rescale; previous operator is of type:", prev_op.type)
                assert False
        # TODO which are the cases where we need to, and can, do a standalone Rescale?
        # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
        # The origin might be the TFLite op QUANTIZE; should we look to see if they can be translated to QUANTIZE?
        # limited to these at the moment:
        elif (
            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
        ):
            # Create NOP performing the RESCALE
            avgpool_op = replace_rescale_with_avg_pool(op)
            avgpool_op.rounding_mode = rounding_mode

            if per_channel:
                # TODO
                avgpool_op.explicit_scaling = explicit_scaling
                print("Warning, unsupported TOSA Rescale")
                assert False
            else:
                avgpool_op.explicit_scaling = explicit_scaling
        else:
            print("Warning, unsupported fusing of TOSA Rescale; previous operator is of type:", prev_op.type)
            assert False
    return op


# TODO modified copy of TFLite, solution for TOSA PAD will change so reuse has not been considered
def convert_pad(op, arch, nng):
    """
    Rewrites PAD operator to an add that copies the IFM to the OFM
    + up to 4 add operators that fill the OFM with zeros at the borders.
    """

    if op.type != Op.Pad:
        return op

    # TODO assuming rank <= 4 and N = 1 for rank == 4
    # This is checked in tosa_supported_operators
    ifm = op.ifm
    assert ifm is not None
    ifm_shape = Shape4D(ifm.shape)
    ofm = op.ofm
    assert ofm is not None
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    rank = len(ifm.shape)
    padding = op.inputs[1].values
    pad_depth = padding[-1]
    if not (pad_depth == 0).all():
        print("Warning: For PAD, padding in depth not supported yet")
        assert False

    top, bottom = 0, 0
    left, right = 0, 0
    if rank > 1:
        left, right = padding[-2][0], padding[-2][1]
    if rank > 2:
        top, bottom = padding[-3][0], padding[-3][1]
    if rank == 4 and not (padding[-4] == 0).all():
        print("Warning: For PAD, padding not supported in first dimension when rank == 4 yet")
        assert False

    # Add op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    shp_top = shp0.with_height(top)
    add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
    add_op.activation = op.activation

    quant = ofm.quantization
    pad_value = ifm.quantization.zero_point
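    # The border tensors are filled with the IFM zero point, so the padded area dequantizes to 0.0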
    # Add operations that fill the borders of the OFM
    if top > 0:
        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_top",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,  # TODO
        )
        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
    if bottom > 0:
        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_bottom",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            np.uint8,
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom))
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right))

    op.type = Op.ConcatTFLite
    return add_op


def convert_table_to_lut(op, arch, nng):
    # Converts table op to a no-op + LUT
    if op.type is not Op.Table:
        return op

    table = op.inputs[1]
    op.inputs.remove(table)
    op.set_ifm_ofm_shapes()

    return convert_to_lut(op, table.values, "table")


def fixup_quantization(op, arch, nng):
    if op.ifm and op.ifm.quantization.zero_point is None:
        op.ifm.quantization.zero_point = 0
    if op.ifm2 and op.ifm2.quantization.zero_point is None:
        op.ifm2.quantization.zero_point = 0
    if not op.forced_output_quantization:
        if op.ofm and op.ofm.quantization and op.ofm.quantization.zero_point is None:
            op.ofm.quantization.zero_point = 0
    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
    assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    return op


def tosa_optimise_graph(nng, arch):
    # Pre-processing step
    pre_process_list = [
        supported_operator_check,
        set_ifm_ofm_op_shapes,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
        )

    # Removal of Transpose
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
        )

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
        )

    # Rewrite concat ops
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
        sg.refresh_after_modification()

    # Removal of reshapes
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
        sg.refresh_after_modification()

    # TODO when and where to best handle calc_scaling_avgpool
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [calc_scaling_avgpool], rewrite_unsupported=False,
        )

    # Rewrite Operators step
    op_rewrite_list = [set_tensor_equivalence, rewrite_rescale, convert_depthwise_to_conv, convert_table_to_lut]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
        )

    # Post-processing step 1
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [], [rewrite_activation, convert_pad, add_padding_fields],
        )

    # Removal of Slice; needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_splitsliceread])
        sg.refresh_after_modification()

    # Post-processing step 2
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)

    return nng