blob: e8d5ac64655320a7d4c96152fd2c43faf3512288 [file] [log] [blame]
Johan Alfvén78fc9bc2023-01-05 15:09:27 +01001# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +02002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020017# Description:
18# Common functions and definitions used during the graph optimization.
Patrik Gustavssonc74682c2021-08-17 14:26:38 +020019from typing import Tuple
20
Patrik Gustavssondf995102021-08-23 15:33:59 +020021import numpy as np
22
Patrik Gustavssonf436ada2021-09-14 14:56:48 +020023from . import lut
Tim Halld6efcd32022-09-02 15:01:01 +010024from .architecture_features import Accelerator
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020025from .data_type import DataType
26from .debug_database import DebugDatabase
Patrik Gustavssondf995102021-08-23 15:33:59 +020027from .errors import UnsupportedFeatureError
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020028from .errors import VelaError
29from .operation import Op
Johan Alfven90724962023-02-02 09:07:48 +010030from .operation_util import create_memcpy
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020031from .shape4d import Shape4D
Patrik Gustavssonf436ada2021-09-14 14:56:48 +020032from .tensor import create_const_tensor
33from .tensor import QuantizationParameters
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020034
# Operators that only present existing data under a new shape/identity (no
# computation); the passes below remove/bypass them.
memory_only_ops = (
    Op.Reshape,
    Op.QuantizedReshape,
    Op.Squeeze,
    Op.ExpandDims,
    Op.Identity,
)

# Ops that depend on the original ifm tensor shape NOT being changed by the
# bypass-memory-op function (see fix_sg_input_output/bypass_memory_only_ops)
original_ifm_shape_ops = (Op.Mean,)
46
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020047
48def _avoid_nhcwb16_for_concat(tens):
49 # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
50 # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
51 # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
52 # and those addresses are always 16 byte aligned due to the NHCWB16 format.
53 return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)
54
55
56def _avoid_nhcwb16_for_split(tens):
57 # If read offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
James Ward6bf16132021-09-08 11:14:20 +010058
59 # Return True if NHCWB16 needs to be avoided
60 def offset_not_aligned(read_offset):
61 return read_offset is not None and (read_offset.depth % 16) != 0
62
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020063 for cons_op in tens.consumer_list:
64 if cons_op.ifm == tens:
James Ward6bf16132021-09-08 11:14:20 +010065 if offset_not_aligned(cons_op.read_offsets[0]):
66 return True
67 if cons_op.ifm2 is not None and cons_op.ifm2 == tens:
68 if offset_not_aligned(cons_op.read_offsets[1]):
69 return True
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +020070 return False
71
72
def _avoid_nhcwb16_for_shapes(tens):
    # check all producers/consumers to see if any op shape is preventing NHCWB16
    full_shape = Shape4D(tens.shape)

    for consumer in tens.consumer_list:
        if consumer.ifm == tens:
            consumer_shape = consumer.ifm_shapes[0]
        elif consumer.type.is_binary_elementwise_op() and consumer.ifm2 == tens:
            consumer_shape = consumer.ifm_shapes[1]
        else:
            assert False
        if consumer_shape != full_shape:
            return True

    # Producers must also see the tensor with its full 4D shape
    return any(producer.ofm_shapes[0] != full_shape for producer in tens.ops)
90
91
def _avoid_nhcwb16_for_memory_only(tens):
    # check all producers/consumers to see if any op is preventing NHCWB16
    for op in tens.consumer_list + tens.ops:
        if op.type == Op.Memcpy:
            return True
    return False
95
96
# Check if non linear format can be used
def check_format_restrictions(tens, arch):
    """Decide whether tens may use the non-linear (NHCWB16) feature-map format.

    Walks all producers and consumers of tens; any condition that requires the
    linear format causes an early return, leaving tens.needs_linear_format
    untouched. Only when every check passes is the flag cleared.
    """
    # A tensor without producers has nothing to decide
    if len(tens.ops) < 1:
        return
    # Subgraph inputs/constants, and tensors with a CPU consumer (marked by a
    # None entry in consumer_list), must stay linear
    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
        cons is None for cons in tens.consumer_list
    ):
        return

    # Check if any of the producers/consumers is run on CPU
    if not all(cons.run_on_npu for cons in tens.consumer_list):
        return
    if not all(prod.run_on_npu for prod in tens.ops):
        return

    # "Concat" ofm exception: every writer must be 16-aligned in C
    if _avoid_nhcwb16_for_concat(tens):
        return

    # "Split" ifm exception: every reader must be 16-aligned in C
    if _avoid_nhcwb16_for_split(tens):
        return

    # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
    if _avoid_nhcwb16_for_shapes(tens):
        return

    # Memory only ifm/ofm exception: DMA ops must use NHCW
    if _avoid_nhcwb16_for_memory_only(tens):
        return

    # Resize bilinear half pixel center implementation requires OFM with linear format to
    # allow stride modification in H/W dimensions.
    for op in tens.ops:
        if op.original_type == Op.ResizeBilinear and op.type == Op.DepthwiseConv2DBias:
            return

    for op in tens.consumer_list:
        if op.type == Op.ReduceSum and (
            tens.dtype == DataType.int32 or arch.accelerator_config == Accelerator.Ethos_U65_512
        ):
            # ReduceSum requires NHWC input
            return
        if op.type == Op.Reshape:
            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
            # consumers do not also need to perform a reshape or if the OFM is going to
            # be processed by CPU operations. No-op reshape consumers with empty lists
            # (those that have no consumers, or null-consumers used as list terminators)
            # must use normal NHWC output.

            # Yields one truthy value for every reshape-chain endpoint that is
            # missing (None) or runs on the CPU.
            def incompatible_consumers(oper):
                if oper and oper.type == Op.Reshape:
                    for consumer in oper.outputs[0].consumer_list:
                        yield from incompatible_consumers(consumer)
                yield not oper or not oper.run_on_npu

            if not any(incompatible_consumers(op)):

                # Yields every Reshape op reachable through the chain from oper.
                def get_rewrites(oper):
                    if oper and oper.type == Op.Reshape:
                        for consumer in oper.outputs[0].consumer_list:
                            yield from get_rewrites(consumer)
                        yield oper

                # Detect no-op reshapes by comparing their full input and output tensor shapes.
                inshape = op.ifm_shapes[0]
                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
                if not (compatible_shape and all(compatible_shape)):
                    return
            else:
                return

    # All checks passed: brick format is permitted for this tensor
    tens.needs_linear_format = False
170
171
def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
    """
    Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
    that provides equivalent results.
    """
    total_padding = needed_total_padding(input_size, stride, filter_size)

    # The trailing padding may need to shrink: it must agree, modulo the stride,
    # with the padding that remains after the leading padding has been applied.
    target_mod = (total_padding - pad_before) % stride
    adjusted_pad_after = pad_after
    while adjusted_pad_after > 0 and adjusted_pad_after % stride != target_mod:
        adjusted_pad_after -= 1
    return pad_before, adjusted_pad_after
185
186
def needed_total_padding(input_size, stride, filter_size):
    """Return the total (before + after) padding needed so that a filter of
    filter_size strided by stride covers input_size with ceil-division output."""
    output_size = -(-input_size // stride)  # ceil(input_size / stride)
    required_input = (output_size - 1) * stride + filter_size
    return max(0, required_input - input_size)
192
193
# Set input/output tensor equivalence to the same id for memory operations
def set_tensor_equivalence(op, arch, nng):
    """Propagate the output tensor's equivalence_id to every input of a
    memory-only op, marking them as views of the same data."""
    if op.type in memory_only_ops:
        shared_id = op.outputs[0].equivalence_id
        for input_tens in op.inputs:
            input_tens.equivalence_id = shared_id
    return op
201
202
def set_ifm_ofm_op_shapes(op, arch, nng):
    """Populate 4D ifm/ofm shapes for NPU ops that need them, unless already set."""
    if not (op.run_on_npu and op.type.needs_shapes()):
        return op
    if not op.ifm_shapes and not op.ofm_shapes:
        # Shapes not yet set
        op.set_ifm_ofm_shapes()
    return op
210
211
def bypass_need_to_keep_ofm_shape(op):
    # The ifm must be replaced by the ofm when bypassing if the rank changes, or
    # if any ofm consumer requires the original ifm shape to be preserved.
    rank_changed = len(op.ifm.shape) != len(op.ofm.shape)
    cons_needs_original_shape = any(
        ofm_cons is not None and ofm_cons.type in original_ifm_shape_ops for ofm_cons in op.ofm.consumer_list
    )
    return cons_needs_original_shape or rank_changed
218
219
def bypass_memory_only_ops(op):
    """Remove a memory-only op from the graph by merging its ifm and ofm.

    Either the ofm replaces the ifm (when the ofm or its shape must persist),
    or the ifm replaces the ofm. Callers must have ensured beforehand that ifm
    and ofm do not BOTH need to persist (see fix_sg_input_output).
    """
    assert op.type in memory_only_ops
    ofm = op.ofm
    ifm = op.ifm

    # Check if ifm/ofm are network ifm/ofm
    ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
    # Check if ifm/ofm is produced respectively consumed by CPU
    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)

    # This case should be handled prior to this function
    assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed))

    if (ifm.shape != ofm.shape) and (ofm_is_sg_ofm or ofm_is_cpu_consumed or bypass_need_to_keep_ofm_shape(op)):
        # Bypassed by replacing ifm with ofm
        ofm.ops = []
        # Re-point ifm's producers to write directly into ofm
        for prev_op in ifm.ops:
            prev_op.outputs = [ofm]
            ofm.ops.append(prev_op)

        # All ifm consumers need to use ofm as input
        for ifm_cons in ifm.consumer_list:
            for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
                if cons_ifm == ifm:
                    ifm_cons.set_input_tensor(ofm, ifm_idx)
    else:
        # Bypassed by replacing ofm with ifm
        for cons in ofm.consumer_list:
            for ifm_idx, cons_ifm in enumerate(cons.inputs):
                if cons_ifm == ofm:
                    cons.set_input_tensor(ifm, ifm_idx)
Patrik Gustavssondf995102021-08-23 15:33:59 +0200254
255
def move_splitsliceread_to_consumer(op, cons_op):
    """Fold a SplitSliceRead op into its consumer cons_op.

    The consumer takes over the read offset/shape of the SplitSliceRead and is
    rewired to read directly from the SplitSliceRead's ifm; the intermediate
    ofm tensor is then disconnected from the graph.
    """
    assert op.type == Op.SplitSliceRead

    if cons_op.ifm == op.ofm:
        # The sliced tensor is the consumer's primary ifm
        cons_op.read_offsets[0] = op.read_offsets[0]
        cons_op.read_shapes[0] = op.read_shapes[0]
        cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0])
        cons_op.ifm_shapes[0] = op.ifm_shapes[0]
    elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm:
        # The sliced tensor is the consumer's secondary ifm
        cons_op.read_offsets[1] = op.read_offsets[0]
        cons_op.read_shapes[1] = op.read_shapes[0]
        cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1])
        cons_op.ifm_shapes[1] = op.ifm_shapes[0]
    # Detach the SplitSliceRead and its ofm from the graph
    op.ofm.consumer_list.remove(cons_op)
    op.ofm.ops = []
    op.ifm.consumer_list.remove(op)
272
273
def check_memory_only_removed(op, arch):
    """Sanity check: by this point every NPU memory-only op must be gone."""
    if op.type in memory_only_ops and op.run_on_npu:
        raise VelaError(f"Memory only {op.type} op {op} expected to have been removed, still remains")
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200278
279
def record_optimised(op, arch):
    """Record op as its own optimisation result in the debug database;
    constants and placeholders are not traced."""
    if op.type in (Op.Const, Op.Placeholder):
        return
    DebugDatabase.add_optimised(op, op)
Patrik Gustavssondf995102021-08-23 15:33:59 +0200283
284
def insert_copy_op_before_op(op):
    """Insert a memcpy between op's ifm and op, so op reads a private copy of
    the tensor while the original ifm stays intact for other consumers."""
    # NOTE(review): unlike insert_copy_op_after_tens, run_on_npu is not set on
    # the new op here — presumably create_memcpy takes care of it; confirm.
    ifm = op.ifm
    ifm_copy = ifm.clone()
    memcpy_op = create_memcpy(f"{ifm.name}_memcpy")
    memcpy_op.add_input_tensor(ifm)
    memcpy_op.set_output_tensor(ifm_copy)
    memcpy_op.set_ifm_ofm_shapes()

    # op now reads the copy instead of the original tensor
    op.set_input_tensor(ifm_copy, 0)

    DebugDatabase.add_optimised(op, memcpy_op)
297
298
def insert_copy_op_after_tens(tens):
    """Insert a memcpy after tens: all existing consumers are re-pointed to the
    copy, leaving the original tensor intact (e.g. as a subgraph output)."""
    consumers = tens.consumer_list.copy()

    # Create a memcpy op with tens as input
    copy_tens = tens.clone()
    copy_op = create_memcpy(tens.name + "_memcpy")
    copy_op.add_input_tensor(tens)
    copy_op.set_output_tensor(copy_tens)
    copy_op.set_ifm_ofm_shapes()
    copy_op.run_on_npu = True

    # Rewire every previous consumer input that referenced tens to the copy
    for consumer in consumers:
        if consumer is None:
            continue
        for idx, consumer_input in enumerate(consumer.inputs):
            if consumer_input == tens:
                consumer.set_input_tensor(copy_tens, idx)

    DebugDatabase.add_optimised(tens.ops[0], copy_op)
318
319
def fix_sg_input_output(op, arch, nng):
    """Prepare a memory-only op for later removal when its ifm/ofm touch
    subgraph boundaries or the CPU, inserting copy (memcpy) ops where both
    tensors must survive the bypass step."""
    if not op.run_on_npu or op.type not in memory_only_ops:
        return op

    # First, skip past any chain of preceding memory-only ops by adopting their
    # ifm; this removes them from the current path.
    prev_op = op.ifm.ops[0]
    while prev_op is not None and prev_op.run_on_npu and prev_op.type in memory_only_ops:
        # Current op is preceded by another memory only op.
        # Replace current op's ifm with the preceding op's ifm. By doing
        # this the preceding op is removed from current path.
        next_prev_op = prev_op.ifm.ops[0]
        if next_prev_op is not None and next_prev_op.run_on_npu and next_prev_op.type in memory_only_ops:
            # Preceding op also has a preceding memory only op
            prev_op = next_prev_op
        else:
            op.set_input_tensor(prev_op.ifm, 0)
            break

    # For the memory only operators we want to remove, tensors are removed.
    # But in order to do this, they cannot be outputs of the sg,
    # this needs to be fixed prior to the removal.
    # Solution is to add a memcpy NOP, to maintain the original tensor.
    # This is also valid when reshape ifm/ofm is produced respectively
    # consumed by CPU

    # Rare case: original_ifm_shape_ops contains ops that depend on the
    # original ifm tensor shape NOT being changed by the bypass memory
    # function. If the memory only op ifm is subgraph ifm/ifm is cpu produced
    # or the ifm is consumed by many, then there is a need to insert a copy
    # NOP before the original_ifm_shape_ops. Also note that the NOP is only inserted
    # before original_ifm_shape_ops. The above is also true when the memory only
    # op changes the rank between the IFM and OFM.
    #
    # Below is an example showing the case when there is a need for a copy NOP
    # when RESHAPE is bypassed by replacing IFM with OFM.
    #
    #              Converts to                  And in bypass_memory
    #                         --->                                  --->
    # -----ADD-----           -----ADD-----           -----ADD-----
    # |           |           |           |           |           |
    # 1x6x6x10    1x6x6x10    1x6x6x10    1x6x6x10    1x6x6x10    1x6x6x10
    # RESHAPE     MEAN        MEMCPY      MEAN        MEMCPY      MEAN
    # |                       |                       |
    # 1x20x3x6                1x6x6x10                1x20x3x6
    # MEAN                    RESHAPE                 MEAN
    #                         |
    #                         1x20x3x6
    #                         MEAN
    ifm_has_multiple_cons = len(op.ifm.consumer_list) > 1

    # Check if operator ifm/ofm are sg ifm/ofm
    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
    # Check if ifm/ofm is produced respectively consumed by CPU
    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)

    if bypass_need_to_keep_ofm_shape(op):
        # Bypass needs to keep OFM shape
        if ifm_has_multiple_cons:
            # Rare case:
            # IFM needs to persist due to multiple consumers and a copy op is needed
            # OFM will replace IFM for the memory only op
            insert_copy_op_before_op(op)
            # One copy added so no need to check for another copy further down
            return op
        elif not (ofm_is_sg_ofm or ofm_is_cpu_consumed):
            # Only one consumer and OFM is not subgraph output or cpu consumed,
            # safe to replace ifm.shape by ofm.shape
            # IFM can then replace OFM for the memory only op and no copy op is needed
            op.ifm.shape = op.ofm.shape

    # Special case when OFM is sg_ofm or cpu_consumed
    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
        # Both ifm and ofm need to persist, but only ifm needs a copy, in order to remove the memory only operator.
        insert_copy_op_after_tens(op.ifm)

    return op
398
399
def convert_depthwise_to_conv(op, arch, nng):
    """Rewrite a depthwise conv as a regular conv when they are equivalent.

    A depthwise conv with ifm depth 1 whose ofm depth equals the depth
    multiplier computes the same result as a Conv2D, so only the op type and
    the weight axis order need to change.
    """
    if op.type != Op.DepthwiseConv2DBias or op.attrs["depth_multiplier"] == 1:
        return op

    ifm_shape = op.ifm_shapes[0]
    ofm_shape = op.ofm_shapes[0]
    weights = op.inputs[1]
    if ifm_shape.depth == 1 and ofm_shape.depth == op.attrs["depth_multiplier"]:
        # Change op type to Conv2d
        op.type = Op.Conv2DBias
        del op.attrs["channel_multiplier"]
        del op.attrs["depth_multiplier"]

        # Swap the last two weight axes so the multiplier axis takes the
        # position Conv2D expects for output channels
        weights.values = np.transpose(weights.values, (0, 1, 3, 2))
        weights.set_all_shapes(list(weights.values.shape))
        DebugDatabase.add_optimised(op, op)
    else:
        raise UnsupportedFeatureError(
            f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},",
            f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}",
        )
    return op
Patrik Gustavssonf436ada2021-09-14 14:56:48 +0200425
426
def convert_to_lut(op, lut_values, lut_name):
    """Rewrite op as an elementwise Add of scalar zero with a LUT activation.

    lut_values holds the table entries; lut_name is only used for the new op
    name. Returns the rewritten op (unchanged if it has no ifm). Only valid
    for 8-bit ifm/ofm of the same dtype.
    """
    # Rewrite the operation by Add with scalar 0 + LUT activation
    ifm = op.ifm
    ofm = op.ofm
    if ifm is None:
        return op
    # LUT lookup only works on 8-bit data
    assert ifm.dtype.size_in_bytes() == 1
    op.type = Op.Add
    op.name = op.name + "_lut_" + lut_name
    # Mark as no-op to enable potential fusing optimizations
    op.attrs["is_nop"] = True
    # Create an input tensor containing scalar zero
    quantization = QuantizationParameters(0.0, 255.0)
    quantization.scale_f32 = ifm.quantization.scale_f32
    quantization.zero_point = 0
    tens = create_const_tensor(ifm.name + "_scalar0", [], ifm.dtype, [0], quantization=quantization)
    op.add_input_tensor(tens)
    op.ifm_shapes.append(Shape4D(tens.shape))  # TODO no shape?

    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
    # should be the same as the IFM
    op.forced_output_quantization = ifm.quantization

    # the lut tensor datatype needs to match both; the ofm datatype, because these are the values output; and the
    # datatype used to generate the lut values (which is probably the ifm datatype), because we want to avoid any
    # potential overflow errors in create_lut_tensor() caused by converting Python int (which could represent a uint)
    # to NumPy int. this can be guaranteed by checking that the ifm and ofm datatypes are the same
    assert ifm.dtype == ofm.dtype
    lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, ofm.dtype)
    op.set_activation_lut(lut_tensor)
    op.set_ifm_ofm_shapes()
    DebugDatabase.add_optimised(op, op)
    return op