# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import List
from typing import Optional

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .data_type import DataType
from .debug_database import DebugDatabase
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorBlockTraversal
from .tensor import TensorFormat
from .tensor import TensorPurpose


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address index for the Scratch_fast_tensor


dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}

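# The NPU requires any scalar or broadcast feature map to be placed in IFM2.
# Illustrative example: ifm_shape=[1, 8, 8, 1] with ifm2_shape=[1, 8, 8, 16]
# returns False, since the first feature map broadcasts in depth and must be
# swapped into IFM2 by the caller.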
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
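    # Defaults to TFL (TensorFlow Lite) rounding; overridden below for ops that
    # need truncating or natural rounding, and finally by any explicit
    # "rounding_mode" attribute on the operation.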
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    rounding_mode = op.attrs.get("rounding_mode", rounding_mode)
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # Index from the end, since a 1x1 AvgPool with non-4-dimensional input/output
    # may have been inserted to fuse an activation function.
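    # Horizontal padding only applies at the feature map edges: a stripe that does
    # not start at the left edge gets left=0, and one that does not reach the
    # right edge gets right=0.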
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > 0:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < Block.from_shape(cmd.ifm_tensor.shape).width:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

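    # With spilling enabled the fast scratch area gets its own base pointer;
    # otherwise it aliases the regular scratch region.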
    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[tens.mem_type].value


def get_upscale(op: Operation) -> NpuResamplingMode:
    upscale = NpuResamplingMode.NONE
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        upscale = NpuResamplingMode.NEAREST
    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        upscale = NpuResamplingMode.TRANSPOSE
    return upscale


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (ps.primary_op.activation is None or forced_ofm_quantization is not None)
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    if tens.quantization is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(tens.quantization.zero_point)
    return NpuQuantization(scale_f32=tens.quantization.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the
    # output tensor's quantization (used for LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord)
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides()
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm


def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights"""
    weights = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    weight_substream_offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(weight_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract weight substream offsets and calculate their lengths
    assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0)
    weight_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
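    # One weight substream per NPU core; consecutive offsets delimit each core's slice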
    for core in range(substreams):
        address = weight_addr + weight_substream_offsets[core]
        length = weight_substream_offsets[core + 1] - weight_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        weights.append(addr_range)
    return weights


def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases"""
    biases = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    scale_substream_offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(scale_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract scale substream offsets and calculate their lengths
    assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0)
    scale_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])
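    # The scale tensor is indexed by output depth only, hence the [-1:] slice of
    # the start coordinate above.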

    region = get_region(scale_tensor, arch)
    for core in range(substreams):
        address = scale_addr + scale_substream_offsets[core]
        length = scale_substream_offsets[core + 1] - scale_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        biases.append(addr_range)
    return biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise Exception("Unsupported fused_activation_function = " + faf.name)

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = Block.from_shape(cmd.ifm_tensor.shape).width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch)
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch)
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
        if cmd.scale_tensor is not None:
            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
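    # ps.block_config is stored as [height, width, _, depth]; index 2 is not used here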
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        npu_op.block_traversal = block_traversal_map[cmd.weight_tensor.block_traversal]
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    if op.type == Op.ResizeBilinear and "rescale" in op.attrs:
        npu_op.rescale = op.attrs["rescale"]
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)
    if elemwise_op not in UNARY_ELEMWISE_OPS:
        if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor, so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch)
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = Block.from_shape(cmd.ifm2_tensor.shape).width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.Add, Op.Sub) and "rescale" in op.attrs:
        npu_op.rescale = op.attrs.get("rescale")
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
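            # Fixed output scale used when Sigmoid/Tanh is fused with an elementwise op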
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
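        # LUTs are copied via the dedicated memory-to-memory base pointer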
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor, arch)

    start_coord = cmd.box.start_coord
    src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
    dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)

    if cmd.in_tensor.compressed_values is not None:
        if cmd.out_tensor.purpose == TensorPurpose.FSBias:
            sz = cmd.in_tensor.storage_size()
        else:
            stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
            sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
    else:
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    # Generate register commands
    stream_id = DebugDatabase.add_stream(sg)
    DebugDatabase.set_stream_offset(sg, 0)  # Default to zero, can only set during file writing

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        if not isinstance(npu_op, NpuDmaOperation):
            cmd = npu_op_to_cmd[npu_op]
            DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

    sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd)