# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import List
from typing import Optional

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .data_type import DataType
from .debug_database import DebugDatabase
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import CommandType
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import is_dma_op
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorBlockTraversal
from .tensor import TensorFormat
from .tensor import TensorPurpose


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
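# A stripe whose primary op is, for example, Op.Mul is emitted as
# NPU_OP_ELEMENTWISE with elementwise_mode MUL; op types missing from this
# table are rejected by an assert in create_npu_elementwise_op below.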


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True
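# A minimal illustration (plain Python lists as shapes):
#
#     >>> ifm_ifm2_correct_order([1, 8, 8, 16], [1, 8, 8, 1])
#     True
#     >>> ifm_ifm2_correct_order([1, 8, 8, 1], [1, 8, 8, 16])
#     False
#
# i.e. the broadcast (or scalar) operand must end up as IFM2.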


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    rounding_mode = op.attrs.get("rounding_mode", rounding_mode)
    return rounding_mode
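# Summary of the selection above: ResizeBilinear -> TRUNCATE; int16
# convolutions/depthwise -> NATURAL; a 1x1 average pool used as a concat
# slice write (without fused quantize) -> NATURAL; everything else keeps the
# TFLite-compatible default, and op.attrs["rounding_mode"] overrides all.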


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # Index from the end, since a 1x1 AvgPool may have been added with a non-4-dimensional
    # input/output because an activation function had to be fused
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > 0:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < Block.from_shape(cmd.ifm_tensor.shape).width:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
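# For example, when an IFM is split into several horizontal stripes, each
# stripe uses its own pad_top/pad_bottom from the stripe command instead of
# the op-level explicit padding, and a stripe whose box does not touch the
# left (or right) edge of the IFM contributes no left (or right) padding.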
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100157
158
159def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
Tim Hall1bd531d2020-11-01 20:59:36 +0000160 base_ptr_idx_map = {
161 MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
162 MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
163 MemType.Scratch: BasePointerIndex.ScratchTensor,
164 }
165
166 if arch.is_spilling_enabled():
167 base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100168 else:
Tim Hall1bd531d2020-11-01 20:59:36 +0000169 base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor
170
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100171 return int(base_ptr_idx_map[tens.mem_type])
172
173
174def get_upscale(op: Operation) -> NpuResamplingMode:
175 upscale = NpuResamplingMode.NONE
176 if op.type == Op.ResizeBilinear:
177 # perform nearest neighbor upscale
178 upscale = NpuResamplingMode.NEAREST
179 elif op.type == Op.Conv2DBackpropInputSwitchedBias:
180 # perform insert zero upscale
181 upscale = NpuResamplingMode.TRANSPOSE
182 return upscale
183
184
185def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
186 if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
Louis Verhaard69b31762020-11-17 09:45:20 +0100187 block = ifm_box.get_block()
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100188 else:
Louis Verhaard69b31762020-11-17 09:45:20 +0100189 block = ofm_box.get_block()
190 return block.depth
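# Convolution-like blocks consume the full input depth, so their depth is
# taken from the IFM box; the other block types (depthwise, pooling,
# elementwise) have matching IFM/OFM depths and use the OFM box instead.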
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100191
192
193def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
194 """Checks if quantization should use 0 as zero point"""
195 if tens.dtype == DataType.int32 and is_ifm_tensor:
196 return True
197 if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
198 return False
199 fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
200 forced_ofm_quantization = ps.primary_op.forced_output_quantization
201 use_0 = (
202 (ps.primary_op.activation is None or forced_ofm_quantization is not None)
203 and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
204 and not fused_quantize
205 )
206 return use_0
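# Illustration of the rules above: an int32 IFM always uses zero point 0;
# an AvgPool does so only when there is no concat slice write or fused
# quantize, and any fused activation comes with a forced output quantization.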


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    if tens.quantization is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(tens.quantization.zero_point)
    return NpuQuantization(scale_f32=tens.quantization.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord)
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides()
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm
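# Example (illustrative): a dense int8 NHWC feature map of shape 1x8x8x16
# has byte strides depth=1, width=16 and height=128, so fm.strides becomes
# NpuShape3D(height=128, width=16, depth=1).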


def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights"""
    weights = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    weight_substream_offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(weight_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract weight substream offsets and calculate their lengths
    assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0)
    weight_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
    for core in range(substreams):
        address = weight_addr + weight_substream_offsets[core]
        length = weight_substream_offsets[core + 1] - weight_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        weights.append(addr_range)
    return weights
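# Example (illustrative): substream offsets [0, 128, 256] describe a stream
# split over two cores; the address ranges emitted are (weight_addr + 0,
# 128 bytes) and (weight_addr + 128, 128 bytes) in the weight tensor's region.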


def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases"""
    biases = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    scale_substream_offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(scale_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract scale substream offsets and calculate their lengths
    assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0)
    scale_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])

    region = get_region(scale_tensor, arch)
    for core in range(substreams):
        address = scale_addr + scale_substream_offsets[core]
        length = scale_substream_offsets[core + 1] - scale_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        biases.append(addr_range)
    return biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise Exception("Unsupported fused_activation_function = " + faf.name)

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    act.lookup_table_index = op.activation.lut_index
    return act
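# ReLU-family activations stay NONE_OR_RELU and are expressed purely through
# the clamp range: e.g. a fused Relu6 is assumed to arrive here with
# op.activation.min/max already set to 0 and 6 by the earlier passes.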


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = Block.from_shape(cmd.ifm_tensor.shape).width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch)
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch)
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
        if cmd.scale_tensor is not None:
            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op
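# Note: ps.block_config is a 4-element list; indices 0, 1 and 3 hold the
# block height, width and depth used for npu_op.block_config above.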


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        npu_op.block_traversal = block_traversal_map[cmd.weight_tensor.block_traversal]
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    if op.type == Op.ResizeBilinear and "rescale" in op.attrs:
        npu_op.rescale = op.attrs["rescale"]
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)
    if elemwise_op not in UNARY_ELEMWISE_OPS:
        if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch)
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = Block.from_shape(cmd.ifm2_tensor.shape).width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.Add, Op.Sub) and "rescale" in op.attrs:
        npu_op.rescale = op.attrs.get("rescale")
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor, arch)

    start_coord = cmd.box.start_coord
    src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
    dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)

    if cmd.in_tensor.compressed_values is not None:
        if cmd.out_tensor.purpose == TensorPurpose.FSBias:
            sz = cmd.in_tensor.storage_size()
        else:
            stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
            sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
    else:
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
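# The DMA length is chosen per source: the whole storage size for bias
# tensors headed for fast storage, the size of the covering compressed
# stream for other compressed tensors, and otherwise the address span of
# the box (end address minus start address).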


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    if cmd.cmdtype == CommandType.DMA:
        npu_op = create_dma_op(cmd, arch)
    elif cmd.cmdtype == CommandType.NpuStripe:
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown block type {npu_block_type}"
    # add a link to the high level command for debugging purposes
    npu_op.cmd = cmd
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    # Generate register commands
    stream_id = DebugDatabase.add_stream(sg)
    DebugDatabase.set_stream_offset(sg, 0)  # Default to zero, can only be set during file writing

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        if not is_dma_op(npu_op):
            cmd = npu_op_to_cmd[npu_op]
            DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

    sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db)
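# Typical use (a sketch; assumes the caller iterates the network's subgraphs
# after high level command stream generation):
#
#     for sg in nng.subgraphs:
#         generate_register_command_stream_for_sg(nng, sg, arch, verbose)
#     # each sg.register_command_stream now holds the generated register stream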