# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import List
from typing import Optional

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuKernel
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import CommandType
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .operation import Kernel
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorBlockTraversal
from .tensor import TensorFormat
from .tensor import TensorPurpose


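# Elementwise operations that only take a single input feature map (no IFM2)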
unary_elementwise_ops = set((NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ,))


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor
    Mem2Mem = (1 << 8) | (3 << 0)  # base address slot for memory to memory transfer


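# Maps the internal DataType to the public API NpuDataType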
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


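# Maps the internal TensorBlockTraversal to the public API NpuBlockTraversal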
block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


def to_npu_kernel(kernel: Kernel) -> NpuKernel:
    """Converts the given internally used kernel object to NpuKernel (of public API)"""
    return NpuKernel(
        kernel.width, kernel.height, kernel.stride.x, kernel.stride.y, kernel.dilation.x, kernel.dilation.y
    )


def to_kernel(kernel: Optional[NpuKernel]) -> Kernel:
    """Converts the given public API object to Kernel (used internally)"""
    if kernel is None:
        return Kernel(1, 1)
    return Kernel(kernel.width, kernel.height, kernel.stride_x, kernel.stride_y, kernel.dilation_x, kernel.dilation_y)


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif op.type.is_avgpool_op() and op.memory_function == Op.ConcatSliceWrite and op.kernel.elements_wh() == 1:
        rounding_mode = NpuRoundingMode.NATURAL
    rounding_mode = op.attrs.get("rounding_mode", rounding_mode)
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
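    # Fully connected (VectorProduct) operations have no spatial padding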
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    explicit_padding = list(primary_op.attrs["explicit_padding"])  # (top, left, bottom, right)

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        explicit_padding[0] = cmd.pad_top
        explicit_padding[2] = cmd.pad_bottom

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because an activation function needed to be fused.
    if cmd.ifm_box.start_coord[-2] > 0:
        explicit_padding[1] = 0
    if cmd.ifm_box.end_coord[-2] < cmd.ifm_tensor.shape[-2]:
        explicit_padding[3] = 0
    return NpuPadding(
        top=explicit_padding[0], left=explicit_padding[1], bottom=explicit_padding[2], right=explicit_padding[3]
    )


def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
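    # Scratch_fast tensors only get their own base pointer when fast storage is a separate memory area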
    if arch.feature_map_storage_mem_area == arch.fast_storage_mem_area:
        base_ptr_idx_map = {
            MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
            MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
            MemType.Scratch: BasePointerIndex.ScratchTensor,
            MemType.Scratch_fast: BasePointerIndex.ScratchTensor,
        }
    else:
        base_ptr_idx_map = {
            MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
            MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
            MemType.Scratch: BasePointerIndex.ScratchTensor,
            MemType.Scratch_fast: BasePointerIndex.ScratchFastTensor,
        }
    return int(base_ptr_idx_map[tens.mem_type])


def get_upscale(op: Operation) -> NpuResamplingMode:
    upscale = NpuResamplingMode.NONE
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        upscale = NpuResamplingMode.NEAREST
    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        upscale = NpuResamplingMode.TRANSPOSE
    return upscale


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        shape = ifm_box.get_size_shape()
    else:
        shape = ofm_box.get_size_shape()
    return shape[-1]


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (ps.primary_op.activation is None or forced_ofm_quantization is not None)
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    if tens.quantization is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(tens.quantization.zero_point)
    return NpuQuantization(scale_f32=tens.quantization.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord)
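    # Replace any unset tile addresses (None) with 0 so all tile addresses are integers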
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides()
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm


def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights"""
    weights = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    weight_substream_offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(weight_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract weight substream offsets and calculate their lengths
    assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0)
    weight_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
    for core in range(substreams):
        address = weight_addr + weight_substream_offsets[core]
        length = weight_substream_offsets[core + 1] - weight_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        weights.append(addr_range)
    return weights


def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases"""
    biases = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    scale_substream_offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(scale_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract scale substream offsets and calculate their lengths
    assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0)
    scale_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])

    region = get_region(scale_tensor, arch)
    for core in range(substreams):
        address = scale_addr + scale_substream_offsets[core]
        length = scale_substream_offsets[core + 1] - scale_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        biases.append(addr_range)
    return biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise Exception("Unsupported fused_activation_function = " + faf.name)

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op
    in_shape = cmd.ifm_box.get_size_shape()
    out_shape = cmd.ofm_box.get_size_shape()
    ofm_height = out_shape[-3] if len(out_shape) >= 4 else 1
    ofm_width = out_shape[-2] if len(out_shape) >= 2 else 1
    ofm_depth = out_shape[-1] if len(out_shape) >= 1 else 1
    ifm_height = in_shape[-3] if len(in_shape) >= 4 else 1
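    # Convolution-like blocks take the IFM depth from the IFM box; other blocks use the OFM depth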
    if op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        ifm_depth = in_shape[-1] if len(in_shape) >= 1 else 1
    else:
        ifm_depth = ofm_depth

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch)
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=cmd.ifm_tensor.shape[-2], depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch)
    npu_op.ofm.shape = NpuShape3D(height=ofm_height, width=ofm_width, depth=ofm_depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
        if cmd.scale_tensor is not None:
            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.rounding_mode = get_rounding_mode(op)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
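    # Fully connected (VectorProduct) always uses depth-first weight block traversal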
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        npu_op.block_traversal = block_traversal_map[cmd.weight_tensor.block_traversal]
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    if op.type == Op.ResizeBilinear and "rescale" in op.attrs:
        npu_op.rescale = op.attrs["rescale"]
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)
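    # Binary operations also need the second input (IFM2) to be set up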
    if elemwise_op not in unary_elementwise_ops:
        if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor, so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch)
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            box_shp = cmd.ifm2_box.get_size_shape()
            height = box_shp[-3] if len(box_shp) >= 3 else 1
            npu_op.ifm2.shape = NpuShape3D(height=height, width=cmd.ifm2_tensor.shape[-2], depth=box_shp[-1])
    set_common_op_fields(npu_op, cmd, arch)
    # Check if the output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force the output scale to be the same as the input scale for
        # the resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.Add, Op.Sub) and "rescale" in op.attrs:
        npu_op.rescale = op.attrs.get("rescale")
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor, arch)
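    # LUT transfers go through the dedicated memory-to-memory base pointer slot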
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BasePointerIndex.Mem2Mem
    else:
        dest_region = get_region(cmd.out_tensor, arch)

    start_coord = cmd.box.start_coord
    src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
    dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)

    if cmd.in_tensor.compressed_values is not None:
        if cmd.out_tensor.purpose == TensorPurpose.FSBias:
            sz = cmd.in_tensor.storage_size()
        else:
            stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
            sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
    else:
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    if cmd.cmdtype == CommandType.DMA:
        npu_op = create_dma_op(cmd, arch)
    elif cmd.cmdtype == CommandType.NpuStripe:
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    # add a link to the high level command for debugging purposes
    npu_op.cmd = cmd
    return npu_op