# SPDX-FileCopyrightText: Copyright 2021-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Early optimisation of the TOSA based network graph, using the rewrite_graph module to do the traversal of the graph.
import numpy as np

from . import rewrite_graph
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_memory_only_ops
from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import convert_depthwise_to_conv
from .graph_optimiser_util import move_splitsliceread_to_consumer
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .lut import convert_to_lut
from .operation import ExplicitScaling
from .operation import Op
from .operation import RoundingMode
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_pad_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import shape_num_elements
from .tensor import Tensor


def replace_rescale_with_avg_pool(rescale_op):
    assert rescale_op.type == Op.Rescale

    avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")
    rescale_op_clone = rescale_op.clone()
    op = rescale_op
    op.attrs = avgpool_op.attrs.copy()
    op.type = Op.AvgPool
    DebugDatabase.add_optimised(rescale_op_clone, op)

    return op


def calc_skirt(kernel, input_shape, explicit_padding):
    k_w, k_h = kernel.dilated_wh()
    s_x, s_y = kernel.stride
    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
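    # needed_total_padding is assumed to return the total SAME-style padding for a
    # dimension, e.g. height 224, stride 2 and dilated kernel 3 would give ypad = 1.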

    top, left, bottom, right = explicit_padding
    top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
    left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))

    padding = (top_pad, left_pad, bottom_pad, right_pad)
    skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
    return padding, skirt


def add_padding_fields(op, arch, nng):
    if op.run_on_npu:
        if "explicit_padding" in op.attrs:
            input_shape = op.ifm_shapes[0]

            if op.type == Op.Conv2DBackpropInputSwitchedBias:
                # TODO not yet supported, but there will be a need for separate handling
                assert False
            else:
                padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))

            op.attrs["explicit_padding"] = padding
            op.attrs["skirt"] = skirt

    return op


# Counts the number of leading zeros in a (treated as int32)
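# A minimal illustration of the function below: count_leading_zeros(1) == 31,
# count_leading_zeros(1 << 31) == 0 and count_leading_zeros(0) == 32.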
def count_leading_zeros(a):
    lz = int(32)
    if a != 0:
        mask = 1 << (32 - 1)
        lz = 0
        while (mask & a) == 0:
            mask = mask >> 1
            lz = lz + 1
    return lz


def calc_scaling_avgpool(op, arch, nng):
    if op.type == Op.AvgPool:
        top, left, _, _ = op.attrs["explicit_padding"]
        # TODO Only supported when global scaling can be used,
        # that is, when there is no padding
        assert top == 0 and left == 0
        assert op.explicit_scaling is None
        multiplier = []
        shift = []

        kernel_wh = op.kernel.elements_wh()
        k = 32 - count_leading_zeros(kernel_wh - 1)
        numerator = np.int64(((1 << 30) + 1) << k)
        multiplier.append(numerator // kernel_wh)
        shift.append(30 + k)
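        # The multiplier/shift pair implements a fixed-point reciprocal of the kernel
        # area, so ofm ~= (sum * multiplier) >> shift ~= sum / kernel_wh.
        # Worked example (illustrative): kernel_wh = 9 gives k = 4,
        # multiplier = (((1 << 30) + 1) << 4) // 9 (roughly 2**34 // 9) and shift = 34.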

        op.rounding_mode = RoundingMode.HalfUp
        op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
    return op


def remove_const_transpose(op, arch, nng):
    if op.type == Op.Transpose:
        removed = False
        if len(op.ifm.ops) == 1:
            prev_op = op.ifm.ops[0]
            if prev_op.type == Op.Const:
                # Transpose the Tensor and data and remove Transpose
                # TODO move to Tensor?
                reorder = op.attrs["perms"]
                shape = op.ifm.shape.copy()
                tens = op.ifm

                tens.shape = [shape[idx] for idx in reorder]
                tens.bandwidth_shape = tens.shape
                tens.storage_shape = tens.shape

                if tens.values is not None:
                    tens.values = tens.values.transpose(reorder)

                op.ofm.values = tens.values
                # Bypass the Transpose op
                prev_op.set_output_tensor(op.ofm)
                DebugDatabase.add_optimised(op, prev_op)
                removed = True

        if not removed:
            print("Warning: Cannot remove Transpose, and handling of Transpose is not supported")
            assert False

    return op


def insert_add_copy_for_const(op, ifm_ofm_shape):
    assert op.type == Op.Const
    ofm = op.ofm
    copy_tens = ofm.clone()
    op.set_output_tensor(copy_tens)

    name = ofm.name + "_add"
    ifm2 = create_const_tensor(
        name + "_zero_scalar",
        [1],
        copy_tens.dtype,
        [0],
        quantization=copy_tens.quantization,
    )
    copy_op = create_add_nop(name)
    copy_op.add_input_tensor(copy_tens)
    copy_op.add_input_tensor(ifm2)
    copy_op.set_output_tensor(ofm)
    copy_op.ifm_shapes.append(ifm_ofm_shape)
    copy_op.ifm_shapes.append(Shape4D(ifm2.shape))
    copy_op.ofm_shapes.append(ifm_ofm_shape)
    copy_op.run_on_npu = True

    DebugDatabase.add_optimised(op, copy_op)


# TODO can we change to add for both TFLite and TOSA?
def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
    tens_cons_list_copy = tens.consumer_list.copy()
    copy_tens = tens.clone()

    name = tens.name + "_add"
    ifm2 = create_const_tensor(
        name + "_zero_scalar",
        [1],
        copy_tens.dtype,
        [0],
        quantization=copy_tens.quantization,
    )
    copy_op = create_add_nop(name)
    copy_op.add_input_tensor(tens)
    copy_op.add_input_tensor(ifm2)
    copy_op.set_output_tensor(copy_tens)
    copy_op.ifm_shapes.append(ifm_ofm_shape)
    copy_op.ifm_shapes.append(Shape4D(ifm2.shape))
    copy_op.ofm_shapes.append(ifm_ofm_shape)
    copy_op.run_on_npu = True

    # Set copy_ifm consumers
    for tens_cons in tens_cons_list_copy:
        if tens_cons is not None:
            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
                if cons_inp == tens:
                    tens_cons.set_input_tensor(copy_tens, ifm_idx)

    DebugDatabase.add_optimised(tens.ops[0], copy_op)


def get_shape_for_copy_op(shape):
    # remove dimensions that are set to 1
    new_shape = []
    for dim in shape:
        if dim != 1:
            new_shape.append(dim)
    if not new_shape:
        new_shape = [1]

    rank = len(new_shape)
    if rank > 3:
        # Reshape so that batch becomes 1, by moving elements to H dimension
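        # Illustrative example (assuming Shape4D pads missing leading dimensions with 1):
        # shape [2, 3, 4, 5] -> rank 4 -> n = 2, h = 2 * 3 = 6 -> Shape4D [1, 6, 4, 5]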
        n = rank - 2
        h = 1
        for i in range(n):
            h *= shape[i]
        new_shape = Shape4D(new_shape[n:]).with_height(h)
    else:
        new_shape = Shape4D(new_shape)
    return new_shape


def fix_sg_input_output_tosa(op, arch, nng):

    if op.type == Op.Const and any(ofm_cons is None for ofm_cons in op.ofm.consumer_list):
        # Const operator with sg output, insert copy op before the ofm
        new_shape = get_shape_for_copy_op(op.ofm.shape.copy())
        insert_add_copy_for_const(op, new_shape)
    elif op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
        # For the Reshape operators we want to remove, tensors are removed.
        # But in order to do this, they cannot be outputs of the sg;
        # this needs to be fixed prior to the removal.
        # The solution is to add a copy op, to maintain the original tensor.
        # This is also valid when the reshape ifm/ofm is produced respectively
        # consumed by the CPU

        # Check if operator ifm/ofm are sg ifm/ofm
        ifm_is_sg_ifm = op.ifm.ops[0].type in (
            Op.Placeholder,
            Op.SubgraphInput,
            Op.Const,
        )
        ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
        ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
        # Check if ifm/ofm is produced respectively consumed by the CPU
        ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
        ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)

        if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
            # Both ifm and ofm need to persist, but only ifm needs a copy, in order to remove the Operator
            # Decide on ifm/ofm shapes for the copy op based on ifm
            new_shape = get_shape_for_copy_op(op.ifm.shape.copy())
            insert_add_copy_op_after_tens(op.ifm, new_shape)
    return op


def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
    """Creates an add op for the given concat op/input feature map"""
    ofm = concat_op.ofm
    ifm2 = create_const_tensor(name + "_zero_scalar", [1], ofm.dtype, [0], quantization=ofm.quantization)
    add_op = create_add_nop(name)

    add_op.inputs = [ifm, ifm2]
    add_op.outputs = [ofm]
    add_op.write_offset = write_offset
    add_op.write_shape = ifm_shape
    ofm.ops.append(add_op)
    DebugDatabase.add_optimised(concat_op, add_op)
    add_op.ifm_shapes.append(ifm_shape)
    add_op.ifm_shapes.append(Shape4D(ifm2.shape))
    add_op.ofm_shapes.append(concat_op.ofm_shapes[0])
    add_op.memory_function = Op.ConcatSliceWrite
    return add_op


# TODO Could be further optimized by checking the type of the consumer,
# rather than just mimicking the TFLite behaviour depending on type.
# TOSA bool_t not considered yet
def remove_splitsliceread(op, arch):

    if op.type == Op.SplitSliceRead:
        # Check if it is possible to put the SplitSliceRead on the tensor consumer, or if a copy op (add) needs to be inserted
        if (
            len(op.ofm.consumer_list) == 1
            and op.ofm.consumer_list[0] is not None
            and op.ofm.consumer_list[0].run_on_npu
            and op.ofm.consumer_list[0].type != Op.Reshape
            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
            and op.ofm.dtype in (DataType.uint8, DataType.int8, DataType.int16)
        ):
            # SplitSliceRead can be performed by tensor consumer
            cons_op = op.ofm.consumer_list[0]
            move_splitsliceread_to_consumer(op, cons_op)
        else:
            name = op.name + "_add"
            ofm = op.ofm
            ifm2 = create_const_tensor(
                name + "_zero_scalar",
                [1],
                ofm.dtype,
                [0],
                quantization=ofm.quantization,
            )
            add_op = create_add_nop(name)
            add_op.inputs = [op.ifm, ifm2]
            add_op.outputs = [ofm]
            op.ofm.ops.remove(op)
            op.ofm.ops.append(add_op)
            add_op.ifm_shapes.append(op.ifm_shapes[0])
            add_op.ifm_shapes.append(Shape4D(ifm2.shape))
            add_op.ofm_shapes.append(op.ofm_shapes[0])
            add_op.read_offsets[0] = op.read_offsets[0]
            add_op.read_shapes[0] = op.read_shapes[0]

            op.ifm.consumer_list.remove(op)
            DebugDatabase.add_optimised(op, add_op)


def rewrite_concat(op):
    if not op.run_on_npu or not op.type == Op.Concat:
        return

    offset = 0
    inputs = op.inputs
    axis_4D = op.attrs["axis4D"]
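    # Each input is written into the OFM by an add nop at an increasing offset along
    # axis_4D, e.g. two inputs of depth 8 and 16 concatenated along depth are written
    # at depth offsets 0 and 8 respectively (illustrative values).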

    for idx, inp in enumerate(inputs):
        write_offset = [0, 0, 0, 0]
        write_offset[axis_4D] = offset
        concat_end = offset + op.ifm_shapes[idx][axis_4D]
        create_add_for_concat(
            op,
            op.name + str(idx) + "_add",
            inp,
            op.ifm_shapes[idx],
            Shape4D.from_list(write_offset),
        )
        offset = concat_end
    assert op.ofm_shapes[0][axis_4D] == offset


def remove_memory_ops(op, arch):
    if op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
        bypass_memory_only_ops(op, arch, None)


def rewrite_activation(op, arch, nng):
    if op.type not in (Op.ReluN, Op.Clamp):
        return op

    ifm = op.ifm
    zp = ifm.quantization.zero_point if ifm.quantization.zero_point else 0
    if op.ofm.quantization.zero_point is None:
        op.ofm.quantization.zero_point = zp

    if op.type == Op.Clamp:
        op.attrs["min"] = op.attrs["min_int"] - zp
        op.attrs["max"] = op.attrs["max_int"] - zp
    elif op.type == Op.ReluN:
        op.attrs["max"] = op.attrs["max_int"] - zp

    return op


def rewrite_rescale(op, arch, nng):
    if op.type == Op.Rescale:
        ifm = op.ifm
        ofm = op.ofm

        # some error checking
        assert len(ifm.ops) == 1
        prev_op = ifm.ops[0]

        # TODO currently not supported
        assert len(ifm.consumer_list) == 1

        input_zp = op.attrs["input_zp"]
        output_zp = op.attrs["output_zp"]
        multiplier = op.attrs["multiplier"]
        shift = op.attrs["shift"]
        scale32 = op.attrs["scale32"]
        double_round = op.attrs["double_round"]
        per_channel = op.attrs["per_channel"]

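        # Per the TOSA specification, RESCALE computes approximately
        #   ofm = ((ifm - input_zp) * multiplier) >> shift, then + output_zp,
        # with the rounding behaviour selected by double_round. Here the scaling is
        # captured as ExplicitScaling and fused into a neighbouring op where possible.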
        assert ifm.dtype in (DataType.uint8, DataType.int8, DataType.int32)
        assert ifm.dtype in (DataType.uint8, DataType.int8) or input_zp == 0
        assert ofm.dtype in (DataType.uint8, DataType.int8) or output_zp == 0
        assert (scale32 and ifm.dtype != DataType.int48) or (not scale32 and not double_round)

        # Check that input tensor has the same zp or no zp
        ifm_zp = ifm.quantization.zero_point
        if ifm_zp is not None and ifm_zp != input_zp:
            print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedly")
            assert False
        ifm.quantization.zero_point = input_zp
        ofm.quantization.zero_point = output_zp

        assert per_channel is False, "per_channel rescale not supported"

        for s, m in zip(shift, multiplier):
            # TODO these are the TOSA limitations
            assert m >= 0
            assert 2 <= s <= 62
            # TODO these are the HW limitations
            assert 0 <= s < (1 << 6)
        explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)

        if double_round and scale32:
            rounding_mode = RoundingMode.TFLite
        else:
            rounding_mode = RoundingMode.HalfUp

        if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
            # Currently not supporting per_channel quantization
            if ifm.dtype == DataType.int32 and not per_channel:
                prev_op.explicit_scaling = explicit_scaling
                prev_op.rounding_mode = rounding_mode

                # Bypass op
                prev_op.set_output_tensor(ofm)
                DebugDatabase.add_optimised(op, prev_op)
                return op
            else:
                print(
                    "Warning, unsupported fusing of TOSA Rescale; previous operator is of type:",
                    prev_op.type,
                )
                assert False
        elif (
            (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
            or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
        ):
            # Create NOP performing the RESCALE
            avgpool_op = replace_rescale_with_avg_pool(op)
            avgpool_op.rounding_mode = rounding_mode

            if per_channel:
                # TODO
                avgpool_op.explicit_scaling = explicit_scaling
                print("Warning, unsupported TOSA Rescale")
                assert False
            else:
                avgpool_op.explicit_scaling = explicit_scaling
        elif prev_op.type == Op.Add:
            # Check that the operations before the Add which create the IFMs
            # are Op.Rescale that we can fuse into the add
            rescale_1 = prev_op.ifm.ops[0]
            rescale_2 = prev_op.ifm2.ops[0]

            if rescale_1.type == Op.Rescale and rescale_2.type == Op.Rescale:
                # We are assuming the quantization to be the same for the IFMs
                equal_attributes = ["multiplier", "shift", "double_round"]
                for a in equal_attributes:
                    assert op.attrs[a] == rescale_1.attrs[a] == rescale_2.attrs[a], (
                        f"Only handling equal {a} for all operands "
                        f"({op.attrs[a]}, {rescale_1.attrs[a]}, {rescale_2.attrs[a]}) "
                        "for all the rescale operations to be fused with Add!"
                    )

                assert rescale_1.attrs["input_zp"] == rescale_2.attrs["input_zp"], (
                    f"Only handling equal input_zp ({rescale_1.attrs['input_zp']}!={rescale_2.attrs['input_zp']}) "
                    "for the rescale operations to be fused with Add!"
                )
                for op in [rescale_1, rescale_2]:
                    assert op.attrs["output_zp"] == 0, ""
                    assert op.attrs["per_channel"] is False, "per channel quantization is not supported."

                # Create a new add op to set the rescaled ifms and ofm
                add_op = create_add_nop(prev_op.name + "_fused_rescales")
                add_op.type = Op.Add

                # set the IFMs and OFM for the cloned operation
                add_op.set_output_tensor(ofm)
                add_op.add_input_tensor(rescale_1.ifm)
                add_op.add_input_tensor(rescale_2.ifm)
                add_op.set_ifm_ofm_shapes()

                # Remove the consumption of the IFMs to the Add
                # since we are pruning them from the graph
                for i, c in enumerate(prev_op.ifm.consumers()):
                    if c == rescale_1:
                        prev_op.ifm.consumers().pop(i)
                for i, c in enumerate(prev_op.ifm2.consumers()):
                    if c == rescale_2:
                        prev_op.ifm2.consumers().pop(i)

                DebugDatabase.add_optimised(prev_op, op)
                DebugDatabase.add_optimised(prev_op, rescale_1)
                DebugDatabase.add_optimised(prev_op, rescale_2)
                op = add_op
            else:
                print("Warning, unsupported fusing of TOSA Rescale with Add.")
                assert False
        else:
            print(
                "Warning, unsupported fusing of TOSA Rescale; previous operator is of type:",
                prev_op.type,
            )
            assert False

    return op


def convert_pad_in_width(op):
    """
    Rewrites the PAD operator to an add that copies the IFM to the OFM
    plus up to 2 add operators that fill the OFM with zeros at the left/right borders.
    """
    assert op.type == Op.Pad
    assert op.ifm_shapes[0] is not None and op.ofm_shapes[0] is not None
    ifm = op.ifm
    ofm = op.ofm
    ifm_shape = op.ifm_shapes[0]
    ofm.ops = []
    ofm_shape = op.ofm_shapes[0]

    padding = op.inputs[1].values
    left, right = padding[-2]
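    # E.g. padding values [[0, 0], [0, 0], [1, 2], [0, 0]] give left = 1, right = 2
    # for the width dimension (illustrative values).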

    # Add op that copies IFM to the right place inside the OFM
    shp0 = Shape4D(0, 0, 0, 0)
    add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp0.with_width(left))
    add_op.activation = op.activation

    quant = ofm.quantization
    pad_value = ifm.quantization.zero_point
    ifm.quantization.zero_point = 0
    if left > 0:
        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_left",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp0)
    if right > 0:
        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
        zero_tens = create_const_tensor(
            op.name + "_right",
            shape.as_list(),
            ofm.dtype,
            shape.elements() * [pad_value],
            quantization=quant,
        )
        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
        create_add_for_concat(
            op,
            op.name + "_right",
            zero_tens,
            shape,
            shp0.with_width(ofm_shape.width - right),
        )

    op.type = Op.ConcatTFLite
    return add_op


def convert_table_to_lut(op, arch, nng):
    # Converts table op to a no-op + LUT
    if op.type is not Op.Table:
        return op

    table = op.inputs[1]
    op.inputs.remove(table)
    op.set_ifm_ofm_shapes()

    return convert_to_lut(op, table.values, "table")


def decompose_elem_tensors_hwc(op):
    """
    Decomposes an elementwise op if any of the ifm(s)/ofm is too large in any dimension to be handled by the NPU
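    Illustrative example (assuming clip yields the remaining extent capped at the limit):
    an ofm width of 80000 (> 65535) is split into two parts of width 65535 and 14465,
    each handled by its own part operation.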
    """
    max_t_size = 65535
    ofm_shape = op.write_shape if op.write_shape is not None else op.ofm_shapes[0]
    ifm_shape = op.read_shapes[0] if op.read_shapes[0] is not None else op.ifm_shapes[0]
    ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
    ifm2_shape = op.read_shapes[1] if op.read_shapes[1] is not None else ifm2_shape
    limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)

    if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
        ofm_split = ofm_shape.floordiv_const(max_t_size).add(1, 1, 1, 1)

        for height in range(ofm_split.height):
            for width in range(ofm_split.width):
                for depth in range(ofm_split.depth):
                    ofm_offset = Shape4D(0, height * max_t_size, width * max_t_size, depth * max_t_size)
                    ofm_part_shape = ofm_shape.clip(ofm_offset, limit_shape)
                    ofm_cut = (ofm_offset, ofm_part_shape)

                    ifm_d = depth * max_t_size if ifm_shape.depth == ofm_shape.depth else 0
                    ifm_w = width * max_t_size if ifm_shape.width == ofm_shape.width else 0
                    ifm_h = height * max_t_size if ifm_shape.height == ofm_shape.height else 0
                    ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
                    ifm_part_shape = ifm_shape.clip(ifm_offset, limit_shape)
                    ifm_cut = (ifm_offset, ifm_part_shape)

                    if ifm2_shape is not None:
                        ifm2_d = depth * max_t_size if ifm2_shape.depth == ofm_shape.depth else 0
                        ifm2_w = width * max_t_size if ifm2_shape.width == ofm_shape.width else 0
                        ifm2_h = height * max_t_size if ifm2_shape.height == ofm_shape.height else 0
                        ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
                        ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
                        ifm2_cut = (ifm2_offset, ifm2_part_shape)
                    else:
                        ifm2_cut = (None, None)

                    create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
        op.ofm.ops.remove(op)
        op.ifm.consumer_list.remove(op)
        if op.ifm2 is not None:
            op.ifm2.consumer_list.remove(op)
    return


def create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut):
    part_op = op.clone()
    ifm_read_offset = op.read_offsets[0] if op.read_offsets[0] is not None else Shape4D(0, 0, 0, 0)
    ofm_write_offset = op.write_offset if op.write_offset is not None else Shape4D(0, 0, 0, 0)
    ifm_offset, ifm_shape = ifm_cut
    ofm_offset, ofm_shape = ofm_cut

    part_op.read_offsets[0] = ifm_read_offset + ifm_offset
    part_op.read_shapes[0] = ifm_shape
    part_op.write_offset = ofm_write_offset + ofm_offset
    part_op.write_shape = ofm_shape
    part_op.ifm_shapes = op.ifm_shapes.copy()
    part_op.ofm_shapes = op.ofm_shapes.copy()
    part_op.ifm.consumer_list.append(part_op)
    op.ofm.ops.append(part_op)

    ifm2_offset, ifm2_shape = ifm2_cut
    if ifm2_offset:
        ifm2_read_offset = op.read_offsets[1] if op.read_offsets[1] is not None else Shape4D(0, 0, 0, 0)
        part_op.read_offsets[1] = ifm2_read_offset + ifm2_offset
        part_op.read_shapes[1] = ifm2_shape
        part_op.ifm2.consumer_list.append(part_op)

    return part_op


def get_nhwc_stride(shape):
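    # Returns NHWC element strides for the given Shape4D; e.g. (N, H, W, C) = (2, 3, 4, 5)
    # gives strides (60, 20, 5, 1).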
    stride_x = shape.depth
    stride_y = shape.width * stride_x
    stride_n = shape.height * stride_y
    return Shape4D(stride_n, stride_y, stride_x, 1)


def pad_to_rank(shape, rank):
    """
    Pads a shape to the given rank
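    e.g. pad_to_rank([3, 4], 4) -> [1, 1, 3, 4]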
    """
    while len(shape) < rank:
        shape = [1] + shape

    return shape


def get_elem_shapes_removed_singles(op):
    """
    Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm
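    e.g. ofm [1, 10, 1, 8], ifm [1, 10, 1, 8], ifm2 [1, 1, 1, 8] -> ([10, 8], [10, 8], [1, 8])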
    """
    binary = op.ifm2 is not None
    ofm_shape = op.ofm_shapes[0].as_list() if len(op.ofm_shapes) > 0 else op.ofm.shape
    ifm_shape = op.ifm_shapes[0].as_list() if len(op.ifm_shapes) > 0 else op.ifm.shape
    if binary:
        ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ofm_shapes) else op.ifm2.shape

    rank = max(len(ofm_shape), len(ifm_shape), len(ifm2_shape) if binary else 0)
    ofm_shape = pad_to_rank(ofm_shape, rank)
    ifm_shape = pad_to_rank(ifm_shape, rank)
    if binary:
        ifm2_shape = pad_to_rank(ifm2_shape, rank)

    new_ofm_shape = []
    new_ifm_shape = []
    new_ifm2_shape = []
    for idx in range(rank):
        if ofm_shape[idx] != 1:
            new_ofm_shape.append(ofm_shape[idx])
            new_ifm_shape.append(ifm_shape[idx])
            if binary:
                new_ifm2_shape.append(ifm2_shape[idx])

    if new_ofm_shape == []:
        new_ofm_shape = [1]
        new_ifm_shape = [1]
        new_ifm2_shape = [1] if binary else None

    return new_ofm_shape, new_ifm_shape, new_ifm2_shape


def decomp_dims_elementwise(op):
    """
    Decompose elementwise ops with Rank > 3 (H,W,C).
    If Rank > 3, all the dimensions above H are viewed as the N dimension,
    and the elementwise operation is decomposed into N (of the ofm) elementwise operations
    that read and write with offsets from/to the ifm(s)/ofm.
    Note: broadcast needs to be handled for binary elementwise ops, and TOSA allows for broadcast by both ifm and ifm2
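    Illustrative example: an ofm of shape (2, 3, 10, 10, 8) is viewed as N = 6 slices of
    shape (10, 10, 8), each handled by its own elementwise operation.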
    """

    ifm = op.ifm
    ifm2 = op.ifm2
    ofm = op.ofm
    binary = op.ifm2 is not None

    # Remove dimensions that are all 1
    new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
    rank = len(new_ofm_shape)

    if rank > 3:
        n = rank - 3
        ofm_decomp_shape = Shape4D(new_ofm_shape[0:n])
        ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
        ofm_part_shape = Shape4D(new_ofm_shape[n:])
        op.ofm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))

        if binary:
            ifm_decomp_shape = Shape4D(new_ifm_shape[0:n])
            ifm2_decomp_shape = Shape4D(new_ifm2_shape[0:n])
            ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
            ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
            ifm_part_shape = Shape4D(new_ifm_shape[n:])
            ifm2_part_shape = Shape4D(new_ifm2_shape[n:])
            op.ifm_shapes.append(Shape4D([ifm_decomp_shape.elements()] + new_ifm_shape[n:]))
            op.ifm_shapes.append(Shape4D([ifm2_decomp_shape.elements()] + new_ifm2_shape[n:]))
        else:
            op.ifm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))

        op_list = []
        for height in range(ofm_decomp_shape.height):
            for width in range(ofm_decomp_shape.width):
                for depth in range(ofm_decomp_shape.depth):
                    ofm_offset = Shape4D(0, height, width, depth)
                    ofm_offset = Shape4D(ofm_offset.dot_prod(ofm_decomp_stride), 0, 0, 0)
                    ofm_cut = (ofm_offset, ofm_part_shape)

                    if binary:
                        ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
                        ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
                        ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
                        ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
                        ifm_offset = Shape4D(ifm_offset.dot_prod(ifm_decomp_stride), 0, 0, 0)
                        ifm_cut = (ifm_offset, ifm_part_shape)

                        ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
                        ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
                        ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
                        ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
                        ifm2_offset = Shape4D(ifm2_offset.dot_prod(ifm2_decomp_stride), 0, 0, 0)
                        ifm2_cut = (ifm2_offset, ifm2_part_shape)
                        op_list.append(create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut))
                    else:
                        op_list.append(create_elem_part_op(op, ofm_cut, None, ofm_cut))

        ofm.ops.remove(op)
        ifm.consumer_list.remove(op)
        if binary:
            ifm2.consumer_list.remove(op)

        return op_list
    else:
        op.ofm_shapes.append(Shape4D(new_ofm_shape))
        op.ifm_shapes.append(Shape4D(new_ifm_shape))
        op.ifm_shapes.append(Shape4D(new_ifm2_shape))

        return [op]


def decomp_elementwise(tens, arch, nng):
    """
    Decompose elementwise ops with Rank > 3 (H,W,C).
    Decompose tensors whose size exceeds the NPU max size
    """
    tens_ops = tens.ops.copy()
    for op in tens_ops:
        if op.type.is_elementwise_op():
            decomp_list = decomp_dims_elementwise(op)
            for part_op in decomp_list:
                decompose_elem_tensors_hwc(part_op)
    return tens


def reshape_concat_shape(shape, rank, axis):
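    # Collapses a rank > 3 shape to rank 3 around the given axis, e.g.
    # shape [2, 3, 4, 5] with rank 4 and axis 1 -> [2, 3, 20] (dims before the axis are
    # folded into H, dims after the axis into C).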
    new_h = 1
    for i in range(axis):
        new_h *= shape[i]
    new_c = 1
    for i in range(axis + 1, rank):
        new_c *= shape[i]
    if axis == (rank - 1):
        new_shape = [new_h, shape[axis], 1]
    else:
        new_shape = [new_h, shape[axis], new_c]
    return new_shape


def reshape_concat(op):
    """
    Reshapes concat ops with Rank > 3 (H,W,C).
    """
    ofm = op.ofm
    rank = len(ofm.shape)
    axis = op.attrs["axis"]
    if axis < 0:
        axis += rank

    if rank > 3:
        # Reshape so that the axis to be concatenated is the W dimension
        # Reshape inputs
        for inp in op.inputs:
            new_shape = reshape_concat_shape(inp.shape, rank, axis)
            op.ifm_shapes.append(Shape4D(new_shape))
        # Reshape output
        new_shape = reshape_concat_shape(ofm.shape, rank, axis)
        op.ofm_shapes.append(Shape4D(new_shape))
        op.attrs["axis4D"] = 2
    else:
        for inp in op.inputs:
            op.ifm_shapes.append(Shape4D(inp.shape))
        op.ofm_shapes.append(Shape4D(ofm.shape))
        op.attrs["axis4D"] = axis + (4 - rank)


def decomp_rewrite_concat(tens, arch, nng):
    """
    Decompose concat ops with Rank > 3 (H,W,C).
    Rewrite of concat to elementwise operations
    """
    if len(tens.ops) == 1 and tens.ops[0].type == Op.Concat:
        op = tens.ops[0]

        reshape_concat(op)
        rewrite_concat(op)

        op.ofm.ops.remove(op)
        for inp in op.inputs:
            inp.consumer_list.remove(op)

    return tens


def decomp_rewrite_pad(op, arch):
    """
    Decomposition of pad to elementwise operations.
    For each dimension that needs padding:
    - Create a new PAD operator for that dimension.
      The ifm/ofm are reshaped so that the dimension to be padded is the width dimension
      (rank for each is 3).
    - Rewrite the new PAD operator so there is:
      - 1 Add operator for copying the data
      - 1 Add operator for each left/right part to be padded
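    Illustrative example: padding [[0, 0], [1, 1], [0, 0], [2, 2]] on a rank 4 ifm
    results in one PAD (rewritten to adds) for dimension 1 and one for dimension 3.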
    """
    # TODO several things would be possible to optimize
    # For instance there are cases when it should be possible to pad 2
    # dimensions at the same time.
    if op.type == Op.Pad:
        ofm_elements = shape_num_elements(op.ofm.shape)
        padding = op.inputs[1].values

        rank = len(op.ifm.shape)
        next_ifm = op.ifm
        next_ifm_shape = next_ifm.shape.copy()

        first_pad_rewrite_op = None
        ifm_quant = op.ifm.quantization.clone()

        for dim in range(padding.shape[0]):
            # Check if padding is to be applied in this dimension
            dim_pad = padding[dim]
            if not (dim_pad == 0).all():
                # Reshape so that the dimension to be padded becomes the width dimension
                new_ifm_shape = reshape_concat_shape(next_ifm_shape, rank, dim)
                new_pad_input = np.zeros((4, 2), dtype=np.int32)
                new_pad_input[2] = dim_pad

                pad_op = create_pad_nop(f"{op.name}_dim_{dim}")
                pad_op.add_input_tensor(next_ifm)
                new_pad_tens = op.inputs[1].clone(f"_dim_{dim}")

                name = op.inputs[1].name + f"_dim_{dim}"
                new_pad_tens = create_const_tensor(name, list(new_pad_input.shape), DataType.int32, new_pad_input)
                pad_op.add_input_tensor(new_pad_tens)

                new_ofm_shape = new_ifm_shape.copy()
                new_ofm_shape[-2] = new_ofm_shape[-2] + dim_pad.sum()
                next_ifm_shape[dim] = next_ifm_shape[dim] + dim_pad.sum()

                if Shape4D(new_ofm_shape).elements() == ofm_elements:
                    # Last one, use op.ofm
                    ofm = op.ofm
                else:
                    # add a new ofm Tensor
                    ofm = Tensor(new_ofm_shape, op.ofm.dtype, f"{pad_op.name}_tens")
                    ofm.quantization = ifm_quant.clone()

                pad_op.set_output_tensor(ofm)
                pad_op.ifm_shapes.append(Shape4D(new_ifm_shape))
                pad_op.ofm_shapes.append(Shape4D(new_ofm_shape))
                DebugDatabase.add_optimised(op, pad_op)
                next_ifm = ofm

                # Rewrite the pad op
                converted_pad_op = convert_pad_in_width(pad_op)
                first_pad_rewrite_op = converted_pad_op
            else:
                # Change to Identity operation (will be removed)
                op.type = Op.Identity

        if first_pad_rewrite_op:
            assert op.ofm.shape == next_ifm_shape
            for inp in op.inputs:
                inp.consumer_list.remove(op)
            return first_pad_rewrite_op

    return op


def fixup_quantization(op, arch, nng):
    if op.ifm and op.ifm.quantization.zero_point is None:
        op.ifm.quantization.zero_point = 0
    if op.ifm2 and op.ifm2.quantization.zero_point is None:
        op.ifm2.quantization.zero_point = 0
    if not op.forced_output_quantization:
        if op.ofm and op.ofm.quantization and op.ofm.quantization.zero_point is None:
            op.ofm.quantization.zero_point = 0
    return op


def supported_operator_check(op, arch, nng):
    op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
    assert op.run_on_npu or op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    return op


def tosa_optimise_graph(nng, arch):

    # TODO the supported operator checking needs to be split into semantic and HW checks
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [supported_operator_check],
            rewrite_unsupported=False,
        )

    # Decomposing and rewrite of concat
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False
        )

    # Decomposing of pad
    for idx, sg in enumerate(nng.subgraphs):
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [decomp_rewrite_pad])
        sg.refresh_after_modification()

    # Handle sg input output
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [fix_sg_input_output_tosa],
            rewrite_unsupported=True,
        )

    # Removal of reshapes
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_ops])
        sg.refresh_after_modification()

    # Decomposing of elementwise
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
        )

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [set_ifm_ofm_op_shapes],
            rewrite_unsupported=False,
        )

    # Removal of Transpose
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [remove_const_transpose],
            rewrite_unsupported=False,
        )

    # TODO, when and where to best handle calc_scaling_avgpool
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [calc_scaling_avgpool],
            rewrite_unsupported=False,
        )

    # Rewrite Operators step
    op_rewrite_list = [
        set_tensor_equivalence,
        rewrite_rescale,
        convert_depthwise_to_conv,
        convert_table_to_lut,
    ]

    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            op_rewrite_list,
            rewrite_unsupported=False,
        )

    # Post-processing step 1
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [rewrite_activation, add_padding_fields],
        )

    # Removal of Slice (SplitSliceRead) needs to be done after optimisation has been performed,
    # since ifm/ofm_shapes are of importance to this function
    for sg in nng.subgraphs:
        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_splitsliceread])
        sg.refresh_after_modification()

    # Post-processing step 2
    for idx, sg in enumerate(nng.subgraphs):
        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
            nng,
            sg,
            arch,
            [],
            [fixup_quantization],
        )

    return nng