blob: 9380374ea4f15191c126e1b2ee4911e60802a238 [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
19from enum import IntEnum
20from typing import List
21from typing import Optional
22
23from .api import NpuActivation
24from .api import NpuActivationOp
25from .api import NpuAddressRange
26from .api import NpuBlockOperation
27from .api import NpuBlockTraversal
28from .api import NpuConv2DOperation
29from .api import NpuConvDepthWiseOperation
30from .api import NpuDataType
31from .api import NpuDmaOperation
32from .api import NpuElementWiseOp
33from .api import NpuElementWiseOperation
34from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010035from .api import NpuLayout
36from .api import NpuOperation
37from .api import NpuPadding
38from .api import NpuPoolingOp
39from .api import NpuPoolingOperation
40from .api import NpuQuantization
41from .api import NpuResamplingMode
42from .api import NpuRoundingMode
43from .api import NpuShape3D
44from .api import NpuTileBox
45from .architecture_features import ArchitectureFeatures
Louis Verhaard69b31762020-11-17 09:45:20 +010046from .architecture_features import Block
Louis Verhaarde8a5a782020-11-02 18:04:27 +010047from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010048from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000049from .errors import UnsupportedFeatureError
Louis Verhaarde8a5a782020-11-02 18:04:27 +010050from .high_level_command_stream import Box
51from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010052from .high_level_command_stream import DMA
53from .high_level_command_stream import NpuStripe
Louis Verhaarde8a5a782020-11-02 18:04:27 +010054from .operation import NpuBlockType
55from .operation import Op
56from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010057from .register_command_stream_generator import generate_command_stream
58from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010059from .register_command_stream_util import to_npu_kernel
60from .register_command_stream_util import UNARY_ELEMWISE_OPS
Louis Verhaarde8a5a782020-11-02 18:04:27 +010061from .tensor import MemType
62from .tensor import Tensor
63from .tensor import TensorBlockTraversal
64from .tensor import TensorFormat
65from .tensor import TensorPurpose
66
67
class BasePointerIndex(IntEnum):
    """Base pointer indices; get_region() returns the value of one of these members."""

    # Base address index for the Weight tensor
    WeightTensor = 0
    # Base address index for the Scratch_tensor in the TensorArena
    ScratchTensor = 1
    # Base address for the Scratch_fast_tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010072
73
# Maps a tensor data type to the data type enum value used by the NPU API.
# Only the types listed here can be used for feature maps (see create_feature_map).
dtype_map = dict(
    (
        (DataType.uint8, NpuDataType.UINT8),
        (DataType.int8, NpuDataType.INT8),
        (DataType.uint16, NpuDataType.UINT16),
        (DataType.int16, NpuDataType.INT16),
        (DataType.int32, NpuDataType.INT32),
    )
)
81
82
# Maps the block traversal of a weight tensor to the block traversal enum value used by the NPU API
block_traversal_map = dict(
    (
        (TensorBlockTraversal.DepthFirst, NpuBlockTraversal.DEPTH_FIRST),
        (TensorBlockTraversal.PartKernelFirst, NpuBlockTraversal.PART_KERNEL_FIRST),
    )
)
87
88
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# Op types that are missing here are rejected by create_npu_elementwise_op.
elementwise_op_map = dict(
    (
        (Op.Mul, NpuElementWiseOp.MUL),
        (Op.Add, NpuElementWiseOp.ADD),
        (Op.Sub, NpuElementWiseOp.SUB),
        (Op.Minimum, NpuElementWiseOp.MIN),
        (Op.Maximum, NpuElementWiseOp.MAX),
        (Op.LeakyRelu, NpuElementWiseOp.LRELU),
        (Op.Abs, NpuElementWiseOp.ABS),
        (Op.CLZ, NpuElementWiseOp.CLZ),
        (Op.SHR, NpuElementWiseOp.SHR),
        (Op.SHL, NpuElementWiseOp.SHL),
    )
)
102
103
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the two elementwise inputs are already in a valid IFM/IFM2 order.

    Returns False when the operands need to be swapped, which is the case when
    IFM is a scalar (empty shape) or when IFM has a dimension of size 1 where
    IFM2 has a different size. Dimensions are compared pairwise starting from
    the first one, up to the length of the shorter shape.
    """
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    # Broadcasted FM needs to be in IFM2
    ifm_is_broadcasted = any(dim != dim2 and dim == 1 for dim, dim2 in zip(ifm_shape, ifm2_shape))
    return not ifm_is_broadcasted
116
117
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used

    A "rounding_mode" entry in op.attrs always takes precedence over the mode
    that is derived from the operation itself.
    """
    if op.type == Op.ResizeBilinear:
        derived_mode = NpuRoundingMode.TRUNCATE
    elif (
        # Convolutions with 16-bit IFM
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ) or (
        # Average pool with single element kernel, written as concat slice, without fused quantize
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        derived_mode = NpuRoundingMode.NATURAL
    else:
        derived_mode = NpuRoundingMode.TFL
    return op.attrs.get("rounding_mode", derived_mode)
137
138
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Creates the padding to be used for the given stripe

    Vector products use no padding. For all other operations the starting point
    is the op's "explicit_padding" attribute, which is adjusted to the part of
    the IFM that the stripe covers.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)

    explicit_top, explicit_left, explicit_bottom, explicit_right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if cmd.is_first_h_stripe and cmd.is_last_h_stripe:
        pad_top, pad_bottom = explicit_top, explicit_bottom
    else:
        pad_top, pad_bottom = cmd.pad_top, cmd.pad_bottom

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because of activation function needed to be fused.
    start_coord = cmd.ifm_box.start_coord
    starts_after_left_edge = len(start_coord) >= 2 and start_coord[-2] > 0
    pad_left = 0 if starts_after_left_edge else explicit_left

    end_coord = cmd.ifm_box.end_coord
    ends_before_right_edge = len(end_coord) >= 2 and end_coord[-2] < Block.from_shape(cmd.ifm_tensor.shape).width
    pad_right = 0 if ends_before_right_edge else explicit_right

    return NpuPadding(top=pad_top, left=pad_left, bottom=pad_bottom, right=pad_right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100156
157
def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer index (region) that matches the tensor's mem type

    Scratch_fast tensors only get a region of their own when spilling is
    enabled; otherwise they share the region of the scratch tensor.
    A mem type that is not listed below results in a KeyError.
    """
    if arch.is_spilling_enabled():
        scratch_fast_idx = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_idx = BasePointerIndex.ScratchTensor

    region_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_idx,
    }
    return region_by_mem_type[tens.mem_type].value
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100171
172
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode that is to be used for the operation"""
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        return NpuResamplingMode.TRANSPOSE
    # all other operations: no upscaling
    return NpuResamplingMode.NONE
182
183
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to be used for the given block type

    For convolution, vector product and reduce sum the depth is taken from the
    IFM box, for every other block type from the OFM box.
    """
    depth_from_ifm = npu_block_type in (
        NpuBlockType.ConvolutionMxN,
        NpuBlockType.VectorProduct,
        NpuBlockType.ReduceSum,
    )
    depth_box = ifm_box if depth_from_ifm else ofm_box
    return depth_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100190
191
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    # int32 inputs always use zero point 0
    if is_ifm_tensor and tens.dtype == DataType.int32:
        return True

    primary_op = ps.primary_op
    if primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False

    has_fused_quantize = any(fused_op.type == Op.Quantize for fused_op in ps.ops)
    forced_ofm_quantization = primary_op.forced_output_quantization
    if primary_op.activation is not None and forced_ofm_quantization is None:
        # An activation without forced output quantization needs the real zero point
        return False
    return primary_op.memory_function != Op.ConcatSliceWrite and not has_fused_quantize
206
207
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2

    Returns None if the tensor has no quantization.
    """
    quant = tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
217
218
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM

    Returns None if neither the primary op nor the tensor provides a quantization.
    """
    # The operation's forced output quantization, if present, is used instead of the
    # output tensor's quantization (used in LUTs)
    forced_quant = ps.primary_op.forced_output_quantization
    ofm_quant = tens.quantization if forced_quant is None else forced_quant
    if ofm_quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
232
233
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: List[int]) -> NpuFeatureMap:
    """Creates feature map with common fields populated

    Sets region, data type, layout, tiles and strides. Shape and quantization
    are not set here.

    tens: tensor that backs the feature map
    box: part of the tensor that is accessed
    arch: architecture, used to determine the region
    fm_shape: shape passed on to the tensor when calculating the tile addresses
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(box.start_coord, box.end_coord, fm_shape)
    # Addresses can be None (presumably for tiles that are not used - TODO confirm); those become 0.
    # A new list is built so that the list handed out by the tensor is not modified.
    tile_addresses = [0 if addr is None else int(addr) for addr in addresses]
    fm.tiles = NpuTileBox(height_0=height_0, height_1=height_1, width_0=width_0, addresses=tile_addresses)
    strides = tens.get_strides()
    # Note the index order: strides[1] is used as depth stride, strides[2] as height and strides[3] as width
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm
255
256
def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights

    One address range is created for every substream of the compressed weight
    stream that the start coordinate of the weight box belongs to.
    """
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    # Offset list must terminate with full stream length, so there is one offset more than substreams
    num_substreams = len(offsets) - 1
    assert len(offsets) > 1 and (offsets[0] == 0)

    base_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
    # Each substream starts at its offset and extends up to the next offset
    return [
        NpuAddressRange(region, int(base_addr + offsets[idx]), int(offsets[idx + 1] - offsets[idx]))
        for idx in range(num_substreams)
    ]
274
275
def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases

    The stream index is looked up via the weight tensor; the substream offsets
    and the addresses are taken from the scale tensor.
    """
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    # Offset list must terminate with full stream length, so there is one offset more than substreams
    num_substreams = len(offsets) - 1
    assert len(offsets) > 1 and (offsets[0] == 0)

    # Only the last coordinate of the weight box is used to address the scale tensor
    base_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])
    region = get_region(scale_tensor, arch)
    # Each substream starts at its offset and extends up to the next offset
    return [
        NpuAddressRange(region, int(base_addr + offsets[idx]), int(offsets[idx + 1] - offsets[idx]))
        for idx in range(num_substreams)
    ]
296
297
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function

    Raises UnsupportedFeatureError if the fused activation is neither Tanh,
    Sigmoid, LUT nor a relu op.
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    dedicated_act_ops = {
        Op.Tanh: NpuActivationOp.TANH,
        Op.Sigmoid: NpuActivationOp.SIGMOID,
        Op.LUT: NpuActivationOp.TABLE_LOOKUP,
    }
    if faf in dedicated_act_ops:
        act_op = dedicated_act_ops[faf]
    elif faf.is_relu_op():
        act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    npu_act = NpuActivation(act_op)
    npu_act.min = activation.min
    npu_act.max = activation.max
    npu_act.lookup_table_index = activation.lut_index
    return npu_act
318
319
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation

    Populates IFM, OFM, weights, biases, activation, rounding mode and block
    config; padding and kernel are only set for non-elementwise operations.
    Returns the npu_op that was passed in.
    """
    ps = cmd.ps
    primary_op = ps.primary_op

    # IFM: height from the stripe's box, width from the whole tensor
    ifm = npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    ifm.shape = NpuShape3D(
        height=cmd.ifm_box.get_block().height,
        width=Block.from_shape(cmd.ifm_tensor.shape).width,
        depth=get_ifm_depth(primary_op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box),
    )
    ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: all dimensions from the stripe's box
    ofm_block = cmd.ofm_box.get_block()
    ofm = npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
        # Biases are only looked at when there are weights
        if cmd.scale_tensor is not None:
            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)

    npu_op.activation = create_npu_activation(primary_op)
    npu_op.fused_quantize = any(fused_op.type == Op.Quantize for fused_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)

    # Entries 0, 1 and 3 of the block config are used as height, width and depth
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = get_upscale(primary_op)
    return npu_op
352
353
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    conv_op = NpuConv2DOperation()
    set_common_op_fields(conv_op, cmd, arch)
    is_vector_product = cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct
    # Vector products always use depth first; otherwise the traversal of the weight tensor is used
    conv_op.block_traversal = (
        NpuBlockTraversal.DEPTH_FIRST if is_vector_product else block_traversal_map[cmd.weight_tensor.block_traversal]
    )
    return conv_op
363
364
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    # No depthwise specific fields; set_common_op_fields returns the operation it was given
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
370
371
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation

    ResizeBilinear is executed as average pooling.
    """
    primary_op = cmd.ps.primary_op
    op_type = primary_op.type
    pool_op = NpuPoolingOp.AVERAGE
    if op_type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {primary_op.type}"

    pooling_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(pooling_op, cmd, arch)
    # Pooling specific info
    if primary_op.type == Op.ResizeBilinear and "rescale" in primary_op.attrs:
        pooling_op.rescale = primary_op.attrs["rescale"]
    return pooling_op
391
392
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation

    For binary operations the inputs are swapped if needed, so that a scalar or
    broadcasted input ends up as IFM2. The swap modifies both cmd and cmd.ps.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    is_binary = elemwise_op not in UNARY_ELEMWISE_OPS
    if is_binary:
        needs_swap = not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape)
        if needs_swap:
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            # NOTE(review): ps.ifm_shapes is swapped in place; if one ps is shared by several
            # stripe commands the swap would be repeated for every command - confirm this is safe
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True

        ifm2_tensor = cmd.ifm2_tensor
        npu_op.ifm2 = create_feature_map(ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, ifm2_tensor)
        if ifm2_tensor.shape == []:
            # scalar
            assert ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            # Height and depth from the stripe's box, width from the whole tensor
            ifm2_block = cmd.ifm2_box.get_block()
            npu_op.ifm2.shape = NpuShape3D(
                height=ifm2_block.height,
                width=Block.from_shape(ifm2_tensor.shape).width,
                depth=ifm2_block.depth,
            )

    # Has to be done after the swap, as it reads cmd.ifm_tensor, cmd.ifm_box and ps.ifm_shapes[0]
    set_common_op_fields(npu_op, cmd, arch)

    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.Add, Op.Sub) and "rescale" in op.attrs:
        npu_op.rescale = op.attrs.get("rescale")
    if (
        op.type in (Op.Add, Op.Mul, Op.Sub)
        and op.activation is not None
        and op.activation.op_type in (Op.Sigmoid, Op.Tanh)
    ):
        # Checked last, so this overrides an output scale that was selected above
        output_scale = 1 / 0x3000

    if output_scale is not None:
        # Only the scale is replaced, the zero point of the OFM is kept
        ofm_zero_point = npu_op.ofm.quantization.zero_point
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=ofm_zero_point)
    return npu_op
436
437
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation

    Source and destination address ranges both start at the address of the
    box's start coordinate and have the same length.
    """
    src_tens = cmd.in_tensor
    dest_tens = cmd.out_tensor

    src_region = get_region(src_tens, arch)
    # LUT destinations use the mem2mem base pointer index instead of the tensor's own region
    is_lut_dest = dest_tens.purpose == TensorPurpose.LUT
    dest_region = BASE_PTR_INDEX_MEM2MEM if is_lut_dest else get_region(dest_tens, arch)

    start_coord = cmd.box.start_coord
    src_addr = src_tens.address_for_coordinate(start_coord)
    dest_addr = dest_tens.address_for_coordinate(start_coord)

    if src_tens.compressed_values is None:
        # Uncompressed: length is the distance between the addresses of the box's end and start
        sz = src_tens.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    elif dest_tens.purpose == TensorPurpose.FSBias:
        sz = src_tens.storage_size()
    else:
        # Compressed: length of the compressed stream that the start coordinate belongs to
        stream_index = src_tens.compressed_stream_index_from_coord(start_coord)
        sz = src_tens.size_of_compressed_stream(stream_index)

    return NpuDmaOperation(
        NpuAddressRange(src_region, int(src_addr), int(sz)),
        NpuAddressRange(dest_region, int(dest_addr), int(sz)),
    )
461
462
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    DMA commands become NpuDmaOperation; NpuStripe commands are converted
    according to the npu block type of their primary op.

    Raises TypeError if cmd is neither a DMA nor an NpuStripe, and
    AssertionError if the block type of the stripe is not handled. Both are
    raised explicitly, so that an unexpected command can never fall through to
    returning an unassigned variable.
    """
    if isinstance(cmd, DMA):
        return create_dma_op(cmd, arch)
    if not isinstance(cmd, NpuStripe):
        raise TypeError(f"Unknown command class {type(cmd).__name__}")

    npu_block_type = cmd.ps.primary_op.type.npu_block_type
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
        return create_npu_conv2d_op(cmd, arch)
    if npu_block_type == NpuBlockType.ConvolutionDepthWise:
        return create_npu_conv_depthwise_op(cmd, arch)
    if npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
        return create_npu_pool_op(cmd, arch)
    if npu_block_type == NpuBlockType.ElementWise:
        return create_npu_elementwise_op(cmd, arch)
    # Raised rather than asserted, so that the check is not stripped when running with -O
    raise AssertionError(f"Unknown command type {npu_block_type}")
Louis Verhaard1e170182020-11-26 11:42:04 +0100481
482
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_ops = []
    cmd_for_npu_op = {}  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
            continue
        converted_op = convert_command_to_npu_op(cmd, arch)
        npu_ops.append(converted_op)
        cmd_for_npu_op[converted_op] = cmd

    # Generate register commands
    stream_id = DebugDatabase.add_stream(sg)
    DebugDatabase.set_stream_offset(sg, 0)  # Default to zero, can only set during file writing

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        # DMA operations are not added to the debug database
        if isinstance(npu_op, NpuDmaOperation):
            return
        originating_cmd = cmd_for_npu_op[npu_op]
        DebugDatabase.add_command(stream_id, offset, originating_cmd.ps.primary_op)

    sg.register_command_stream = generate_command_stream(npu_ops, arch, verbose, add_to_debug_db, cmd_for_npu_op)
Dwight Lidman9b43f842020-12-08 17:56:44 +0100505 sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd)