blob: 8e4d33a5515d6a0912a47893e62da7a703df94b5 [file] [log] [blame]
Louis Verhaarde8a5a782020-11-02 18:04:27 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
19from enum import IntEnum
20from typing import List
21from typing import Optional
22
23from .api import NpuActivation
24from .api import NpuActivationOp
25from .api import NpuAddressRange
26from .api import NpuBlockOperation
27from .api import NpuBlockTraversal
28from .api import NpuConv2DOperation
29from .api import NpuConvDepthWiseOperation
30from .api import NpuDataType
31from .api import NpuDmaOperation
32from .api import NpuElementWiseOp
33from .api import NpuElementWiseOperation
34from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010035from .api import NpuLayout
36from .api import NpuOperation
37from .api import NpuPadding
38from .api import NpuPoolingOp
39from .api import NpuPoolingOperation
40from .api import NpuQuantization
41from .api import NpuResamplingMode
42from .api import NpuRoundingMode
43from .api import NpuShape3D
44from .api import NpuTileBox
45from .architecture_features import ArchitectureFeatures
Louis Verhaard69b31762020-11-17 09:45:20 +010046from .architecture_features import Block
Louis Verhaarde8a5a782020-11-02 18:04:27 +010047from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010048from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000049from .errors import UnsupportedFeatureError
Louis Verhaarde8a5a782020-11-02 18:04:27 +010050from .high_level_command_stream import Box
51from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010052from .high_level_command_stream import DMA
53from .high_level_command_stream import NpuStripe
Louis Verhaarde8a5a782020-11-02 18:04:27 +010054from .operation import NpuBlockType
55from .operation import Op
56from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010057from .register_command_stream_generator import generate_command_stream
58from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010059from .register_command_stream_util import to_npu_kernel
60from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000061from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010062from .tensor import MemType
63from .tensor import Tensor
64from .tensor import TensorBlockTraversal
65from .tensor import TensorFormat
66from .tensor import TensorPurpose
67
68
class BasePointerIndex(IntEnum):
    """Base pointer (region) indices of the memory areas the NPU addresses.

    get_region() returns these values for tensors, so changing the numbering changes the regions
    that are programmed for every feature map, weight and bias access.
    """

    WeightTensor = 0  # region of the weight tensor (permanent memory)
    ScratchTensor = 1  # region of the scratch tensor in the tensor arena
    ScratchFastTensor = 2  # region of the scratch-fast tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010073
74
# Maps Vela's internal DataType to the NpuDataType of the public NPU API.
# create_feature_map indexes this dict directly, so a feature map whose dtype is
# not listed here raises KeyError there.
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
82
83
# Maps a weight tensor's TensorBlockTraversal to the equivalent NpuBlockTraversal.
# Looked up by create_npu_conv2d_op for convolutions that are not VectorProduct
# (VectorProduct always uses DEPTH_FIRST).
block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}
88
89
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# RescaleAdd is executed as a plain ADD; its extra rescale is attached separately by
# create_npu_elementwise_op. Op types missing here fail the assert in create_npu_elementwise_op.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
104
105
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Checks whether IFM and IFM2 of a binary elementwise op are already in the required order.

    A scalar or broadcasted feature map has to be the IFM2 operand. This returns False when the IFM
    is a scalar, or when the IFM is 1 in some dimension where IFM2 is not (i.e. the IFM would need
    broadcasting), meaning the operands must be swapped.

    Shapes of different rank are aligned from the innermost dimension, with missing leading
    dimensions treated as 1 (numpy-style broadcasting).

    :param ifm_shape: shape of the first input; an empty shape denotes a scalar
    :param ifm2_shape: shape of the second input; an empty shape denotes a scalar
    :return: True if the operands can be used as-is, False if they should be swapped
    """
    if len(ifm_shape) == 0:
        # Scalar needs to be in IFM2
        return False
    if len(ifm2_shape) == 0:
        return True

    # Left-pad the lower-rank shape with 1s so dimensions are compared from the innermost axis;
    # a plain zip() would pair unrelated axes from the front and truncate the longer shape.
    rank = max(len(ifm_shape), len(ifm2_shape))
    padded_ifm = [1] * (rank - len(ifm_shape)) + list(ifm_shape)
    padded_ifm2 = [1] * (rank - len(ifm2_shape)) + list(ifm2_shape)
    for ifm, ifm2 in zip(padded_ifm, padded_ifm2):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True
118
119
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used

    The default is TRUNCATE for ResizeBilinear, NATURAL for int16 (depthwise) convolutions and for
    average pools with a single-element kernel that write into a ConcatSliceWrite without a fused
    Quantize, and TFL otherwise. An explicit "rounding_mode" entry in op.attrs always wins.
    """
    if op.type == Op.ResizeBilinear:
        default_mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ) or (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        default_mode = NpuRoundingMode.NATURAL
    else:
        default_mode = NpuRoundingMode.TFL
    return op.attrs.get("rounding_mode", default_mode)
139
140
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Computes the padding for the stripe described by cmd.

    VectorProduct operations are never padded. Otherwise the op's "explicit_padding" attribute is
    used, except that:
    - top/bottom come from the stripe itself when the IFM is streamed in several horizontal stripes;
    - left/right are dropped when the stripe's IFM box does not start at the left edge / end at the
      right edge of the IFM tensor.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)

    pad_top, pad_left, pad_bottom, pad_right = primary_op.attrs["explicit_padding"]

    streams_horizontally = not (cmd.is_first_h_stripe and cmd.is_last_h_stripe)
    if streams_horizontally:
        # With horizontal IFM streaming each stripe carries its own top/bottom padding
        pad_top, pad_bottom = cmd.pad_top, cmd.pad_bottom

    # Coordinates are indexed from the end: a 1x1 AvgPool added to fuse an activation function may
    # have a non 4-dimensional input/output, so the width coordinate is always the second-to-last one.
    start, end = cmd.ifm_box.start_coord, cmd.ifm_box.end_coord
    if len(start) >= 2 and start[-2] > 0:
        pad_left = 0
    if len(end) >= 2 and end[-2] < Block.from_shape(cmd.ifm_tensor.shape).width:
        pad_right = 0
    return NpuPadding(top=pad_top, left=pad_left, bottom=pad_bottom, right=pad_right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100158
159
def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer (region) index of the memory the tensor is placed in.

    Permanent tensors (NPU or CPU) use the weight region and Scratch tensors the scratch region.
    Scratch_fast tensors get their own region only when spilling is enabled; otherwise they share
    the scratch region. Raises KeyError for any other memory type.
    """
    if arch.is_spilling_enabled():
        scratch_fast_index = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_index = BasePointerIndex.ScratchTensor

    region_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return region_by_mem_type[tens.mem_type].value
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100173
174
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode for the operation.

    ResizeBilinear performs a nearest neighbour upscale, Conv2DBackpropInputSwitchedBias an
    insert-zero (TRANSPOSE) upscale; every other operation uses no resampling.
    """
    if op.type == Op.ResizeBilinear:
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
184
185
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to program for a stripe.

    ConvolutionMxN, VectorProduct and ReduceSum take the depth from the IFM box; all other block
    types take it from the OFM box.
    """
    depth_from_ifm = npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100192
193
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point

    int32 input tensors always use 0. Otherwise only passes whose primary op is AvgPool,
    ResizeBilinear, CLZ or SHL qualify, and only when no Quantize op is fused into the pass, the
    primary op does not write into a ConcatSliceWrite, and it either has no fused activation or has
    a forced output quantization.
    """
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True

    primary_op = ps.primary_op
    if primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if any(op.type == Op.Quantize for op in ps.ops):
        return False
    if primary_op.memory_function == Op.ConcatSliceWrite:
        return False
    return primary_op.activation is None or primary_op.forced_output_quantization is not None
208
209
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2

    Returns None when the tensor has no quantization. The zero point is forced to 0 when
    use_zero_point_0 says so for an input tensor.
    """
    quant = tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
219
220
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM

    The primary op's forced_output_quantization (used for LUTs), when set, replaces the output
    tensor's own quantization. Returns None when neither is available. The zero point is forced to
    0 when use_zero_point_0 says so for an output tensor.
    """
    forced_quant = ps.primary_op.forced_output_quantization
    ofm_quant = tens.quantization if forced_quant is None else forced_quant
    if ofm_quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
234
235
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated

    Sets region, data type, layout, tiles and strides of a new NpuFeatureMap describing the part of
    tens covered by box. The shape and quantization are not set here.

    :param tens: the feature map tensor; its format must be NHWC or NHCWB16
    :param box: start/end coordinates of the accessed part of tens
    :param arch: architecture features, used to resolve the memory region
    :param fm_shape: 4D feature map shape, forwarded to tens.addresses_for_rolling_buffer
    :return: the partially populated feature map
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens, arch)
    # Raises KeyError for dtypes that have no entry in dtype_map
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord, fm_shape)
    # Missing (None) tile addresses are programmed as 0; note the returned list is modified in place.
    # NOTE(review): presumably those tiles are unused by the stripe — confirm
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    # Strides are read as index 1 = depth, 2 = height, 3 = width.
    # NOTE(review): confirm this ordering against Tensor.get_strides
    strides = tens.get_strides()
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm
257
258
def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights

    The compressed weight stream containing weight_box.start_coord is looked up. Its substream
    offsets start at 0 and end with the full stream length, so N + 1 offsets yield N address ranges,
    one per substream.
    """
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    # Offset list must start at 0 and terminate with the full stream length
    assert len(offsets) > 1 and (offsets[0] == 0)

    base_address = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
    # Consecutive offsets delimit one substream each
    return [
        NpuAddressRange(region, int(base_address + begin), int(end - begin))
        for begin, end in zip(offsets, offsets[1:])
    ]
276
277
def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases

    The stream index is taken from the weight tensor at weight_box.start_coord and used to pick the
    scale tensor's substream offsets. These start at 0 and end with the full stream length, so N + 1
    offsets yield N address ranges, one per substream.
    """
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    # Offset list must start at 0 and terminate with the full stream length
    assert len(offsets) > 1 and (offsets[0] == 0)

    # Only the last coordinate of the weight box is used to address the scale tensor
    base_address = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])
    region = get_region(scale_tensor, arch)
    # Consecutive offsets delimit one substream each
    return [
        NpuAddressRange(region, int(base_address + begin), int(end - begin))
        for begin, end in zip(offsets, offsets[1:])
    ]
298
299
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function

    No activation maps to a plain NONE_OR_RELU activation. Tanh, Sigmoid and LUT map to their
    dedicated NPU ops, and any ReLU variant maps to NONE_OR_RELU. For every fused activation the
    min, max and LUT index are copied from op.activation.

    :raises UnsupportedFeatureError: if the fused activation is of any other type
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif faf.is_relu_op():
        act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    npu_activation = NpuActivation(act_op)
    npu_activation.min = activation.min
    npu_activation.max = activation.max
    npu_activation.lookup_table_index = activation.lut_index
    return npu_activation
320
321
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation

    Fills in IFM, OFM, weights, biases, activation, fused_quantize, rounding mode, block config and
    IFM upscaling of npu_op from the stripe command. Padding and kernel are set only for
    non-elementwise ops. npu_op is modified in place and also returned.

    :param npu_op: the block operation to populate
    :param cmd: the stripe command being converted
    :param arch: architecture features, used to resolve memory regions
    :return: npu_op
    """
    ps = cmd.ps
    op = ps.primary_op

    # Height comes from the stripe's IFM box, width from the full IFM tensor shape
    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = Block.from_shape(cmd.ifm_tensor.shape).width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # The OFM shape is taken entirely from the stripe's OFM box
    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
        # create_biases needs the weight tensor to locate the stream, so biases are only set with weights
        if cmd.scale_tensor is not None:
            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(op)
    # The generator's loop variable is local to it, so this does not rebind the outer `op`
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    # NOTE(review): block_config index 2 is skipped; presumably it is [height, width, ?, depth] — confirm
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op
354
355
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation

    VectorProduct ops always traverse depth first; other convolutions use the block traversal
    recorded on their weight tensor.
    """
    conv_op = NpuConv2DOperation()
    set_common_op_fields(conv_op, cmd, arch)
    is_vector_product = cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct
    conv_op.block_traversal = (
        NpuBlockTraversal.DEPTH_FIRST if is_vector_product else block_traversal_map[cmd.weight_tensor.block_traversal]
    )
    return conv_op
365
366
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation

    Depthwise convolutions need nothing beyond the fields filled in by set_common_op_fields, which
    returns the operation it populated.
    """
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
372
373
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation

    Max pools map to MAX, average pools and ResizeBilinear to AVERAGE, and ReduceSum to REDUCE_SUM.
    For ResizeBilinear the op's rescale is also copied onto the pooling operation.
    """
    op = cmd.ps.primary_op
    op_type = op.type
    if op_type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
        # Only reached when asserts are disabled (python -O): fall back to AVERAGE
        pool_op = NpuPoolingOp.AVERAGE

    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    if op_type == Op.ResizeBilinear:
        npu_op.rescale = op.rescale
    return npu_op
393
394
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation

    For binary ops the operands may be swapped so that a scalar or broadcast operand ends up in
    IFM2. The swap is applied in place to cmd and cmd.ps.ifm_shapes and is flagged through
    reversed_operands. Some op types also override the OFM scale or attach a rescale.

    :param cmd: the stripe command; modified in place when the operands are swapped
    :param arch: architecture features
    :return: the populated NpuElementWiseOperation
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar: its single value goes into ifm2_scalar and the IFM2 shape is all zeros.
            # NOTE(review): the assert checks quant_values but the scalar is read from values;
            # confirm both always hold the same number of elements
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            # As for the IFM: height/depth from the stripe's box, width from the full tensor shape
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = Block.from_shape(cmd.ifm2_tensor.shape).width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    # Must run after the possible operand swap above, since it reads cmd.ifm_tensor, cmd.ifm_box
    # and ps.ifm_shapes[0]
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        # LeakyRelu uses its alpha attribute as the OFM scale
        output_scale = op.attrs["alpha"]
    if op.type == Op.RescaleAdd:
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        # A fused Sigmoid/Tanh forces a fixed OFM scale, overriding the resizebilinear case above
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        # Only the scale is replaced; the OFM zero point is kept
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
439
440
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation

    Source and destination addresses are those of cmd.box.start_coord in the input and output
    tensors; LUT destinations use the BASE_PTR_INDEX_MEM2MEM region. The transfer size is:
    - the full storage size for compressed input copied to an FSBias destination,
    - the size of the compressed stream at the start coordinate for other compressed input,
    - otherwise the address span of cmd.box in the input tensor.
    """
    src_tensor = cmd.in_tensor
    dest_tensor = cmd.out_tensor
    src_region = get_region(src_tensor, arch)
    if dest_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(dest_tensor, arch)

    start_coord = cmd.box.start_coord
    src_addr = src_tensor.address_for_coordinate(start_coord)
    dest_addr = dest_tensor.address_for_coordinate(start_coord)

    if src_tensor.compressed_values is None:
        size = src_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    elif dest_tensor.purpose == TensorPurpose.FSBias:
        size = src_tensor.storage_size()
    else:
        size = src_tensor.size_of_compressed_stream(src_tensor.compressed_stream_index_from_coord(start_coord))

    return NpuDmaOperation(
        NpuAddressRange(src_region, int(src_addr), int(size)),
        NpuAddressRange(dest_region, int(dest_addr), int(size)),
    )
464
465
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    DMA commands become an NpuDmaOperation. NpuStripe commands are dispatched on the NPU block type
    of their pass's primary op.

    :raises AssertionError: for an unknown command class or NPU block type
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    else:
        # Without this branch an unexpected command class fell through to `return npu_op`
        # and failed with a confusing UnboundLocalError
        assert 0, f"Unknown command type {type(cmd).__name__}"
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100484
485
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream

    Each high level command is converted to an NpuOperation. NpuStripe commands whose pass has the
    Default NPU block type are skipped with a warning. The register command stream is then generated
    from the operations, and each non-DMA operation is recorded in the DebugDatabase at the offset
    that generate_command_stream passes to the callback. nng is not used here.
    """
    npu_ops = []
    cmd_for_op = {}  # NpuOperation -> the high level command it was converted from
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
            continue
        npu_op = convert_command_to_npu_op(cmd, arch)
        npu_ops.append(npu_op)
        cmd_for_op[npu_op] = cmd

    stream_id = DebugDatabase.add_stream(sg)
    # Default to zero; the real offset can only be set during file writing
    DebugDatabase.set_stream_offset(sg, 0)

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        if isinstance(npu_op, NpuDmaOperation):
            return
        DebugDatabase.add_command(stream_id, offset, cmd_for_op[npu_op].ps.primary_op)

    sg.register_command_stream = generate_command_stream(npu_ops, arch, verbose, add_to_debug_db, cmd_for_op)