# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
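#
# Typical usage: convert_command_to_npu_op(cmd, arch) is called for each high level command
# and returns the NpuOperation that is subsequently turned into register command stream entries.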
from enum import IntEnum
from typing import List
from typing import Optional

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuKernel
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .data_type import DataType
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import CommandType
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .operation import Kernel
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorBlockTraversal
from .tensor import TensorFormat
from .tensor import TensorPurpose


unary_elementwise_ops = set((NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ,))


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor
    Mem2Mem = (1 << 8) | (3 << 0)  # base address slot for memory 2 memory transfer


dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


def to_npu_kernel(kernel: Kernel) -> NpuKernel:
    """Converts the given internally used kernel object to NpuKernel (of public API)"""
    return NpuKernel(
        kernel.width, kernel.height, kernel.stride.x, kernel.stride.y, kernel.dilation.x, kernel.dilation.y
    )


def to_kernel(kernel: Optional[NpuKernel]) -> Kernel:
    """Converts the given public API object to Kernel (used internally)"""
    if kernel is None:
        return Kernel(1, 1)
    return Kernel(kernel.width, kernel.height, kernel.stride_x, kernel.stride_y, kernel.dilation_x, kernel.dilation_y)


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
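    """Returns True if the operands are in the correct order; a scalar/broadcast feature map must be IFM2"""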
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    rounding_mode = op.attrs.get("rounding_mode", rounding_mode)
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
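    """Creates the explicit padding for the stripe; edges that are interior to the full IFM are not padded"""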
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # Indexing from end since a 1x1 Avgpool might have been added with a non-4-dimensional input/output,
    # because an activation function needed to be fused.
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > 0:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < Block.from_shape(cmd.ifm_tensor.shape).width:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
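    """Returns the region (base pointer index) to be used for the given tensor"""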
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return int(base_ptr_idx_map[tens.mem_type])


def get_upscale(op: Operation) -> NpuResamplingMode:
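    """Returns the upscaling (resampling) mode to be used for the operation's IFM"""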
    upscale = NpuResamplingMode.NONE
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        upscale = NpuResamplingMode.NEAREST
    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        upscale = NpuResamplingMode.TRANSPOSE
    return upscale


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
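    """Returns the IFM depth, taken from the IFM box or the OFM box depending on the block type"""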
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (ps.primary_op.activation is None or forced_ofm_quantization is not None)
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    if tens.quantization is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(tens.quantization.zero_point)
    return NpuQuantization(scale_f32=tens.quantization.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord)
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides()
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm


def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights"""
    weights = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    weight_substream_offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(weight_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract weight substream offsets and calculate their lengths
    assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0)
    weight_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
    for core in range(substreams):
        address = weight_addr + weight_substream_offsets[core]
        length = weight_substream_offsets[core + 1] - weight_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        weights.append(addr_range)
    return weights


def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases"""
    biases = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    scale_substream_offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(scale_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract scale substream offsets and calculate their lengths
    assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0)
    scale_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])

    region = get_region(scale_tensor, arch)
    for core in range(substreams):
        address = scale_addr + scale_substream_offsets[core]
        length = scale_substream_offsets[core + 1] - scale_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        biases.append(addr_range)
    return biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise Exception("Unsupported fused_activation_function = " + faf.name)

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = Block.from_shape(cmd.ifm_tensor.shape).width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch)
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch)
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
        if cmd.scale_tensor is not None:
            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        npu_op.block_traversal = block_traversal_map[cmd.weight_tensor.block_traversal]
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    if op.type == Op.ResizeBilinear and "rescale" in op.attrs:
        npu_op.rescale = op.attrs["rescale"]
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)
    if elemwise_op not in unary_elementwise_ops:
        if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch)
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = Block.from_shape(cmd.ifm2_tensor.shape).width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.Add, Op.Sub) and "rescale" in op.attrs:
        npu_op.rescale = op.attrs.get("rescale")
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor, arch)
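    # A LUT load is written via the dedicated memory-to-memory base address slot (Mem2Mem)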
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BasePointerIndex.Mem2Mem
    else:
        dest_region = get_region(cmd.out_tensor, arch)

    start_coord = cmd.box.start_coord
    src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
    dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)

    if cmd.in_tensor.compressed_values is not None:
        if cmd.out_tensor.purpose == TensorPurpose.FSBias:
            sz = cmd.in_tensor.storage_size()
        else:
            stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
            sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
    else:
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    if cmd.cmdtype == CommandType.DMA:
        npu_op = create_dma_op(cmd, arch)
    elif cmd.cmdtype == CommandType.NpuStripe:
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    # add a link to the high level command for debugging purposes
    npu_op.cmd = cmd
    return npu_op