# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuOperationType
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .weight_compressor import NpuWeightTensor
from .weight_compressor import WeightKey


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


# inverse of the resampling_mode_map in the register command stream generator
resampling_mode_inv_map = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


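# Informal examples (not exercised in this module) of how ifm_ifm2_correct_order resolves
# the operand order; shapes are plain Python lists:
#   ifm_ifm2_correct_order([1, 8, 8, 16], [1, 8, 8, 16]) -> True   (shapes match, keep order)
#   ifm_ifm2_correct_order([1, 1, 8, 16], [1, 8, 8, 16]) -> False  (broadcast FM must go in IFM2)
#   ifm_ifm2_correct_order([], [1, 8, 8, 16])            -> False  (scalar must go in IFM2)

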
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type.is_resize_op():
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode


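# Summary of the selection above: resize ops, int16 ConvolutionMxN/ConvolutionDepthWise and
# unfused 1x1 AvgPool concat-slice writes get NATURAL rounding, everything else gets TFL,
# and an explicitly set op.rounding_mode always takes precedence.

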
def create_padding(cmd: NpuStripe, primary_op: Operation, npu_op: NpuBlockOperation) -> NpuPadding:
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # The ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    ifm_read_offset = primary_op.read_offsets[0]
    ifm_read_shape = primary_op.read_shapes[0]
    if ifm_read_offset is None or len(ifm_read_offset) < 2:
        box_start_coord_min = 0
        box_end_coord_max = cmd.ps.ifm_shapes[0].width
    else:
        box_start_coord_min = ifm_read_offset[-2]
        box_end_coord_max = ifm_read_shape[-2]

    # Index from the end, since a 1x1 AvgPool might have been added with non 4-dimensional input/output
    # because an activation function needed to be fused.
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > box_start_coord_min:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < box_end_coord_max:
        right = 0

    # If tile padding is selected, modify the tile base addresses and set NpuPadding to zero.
    if primary_op.attrs.get("padding", None) == Padding.TILE:
        assert cmd.ifm_tensor.format == TensorFormat.NHCWB16, "Tensor format NHCWB16 required to perform tile padding"
        assert npu_op.op_type == NpuOperationType.ConvDepthWise, "Tile padding only supported for depthwise convolution"
        assert npu_op.ifm is not None, "Feature map must be initialized to modify the tile addresses"
        npu_op.ifm.tiles = modify_tile_addresses_for_padding(
            npu_op.ifm.tiles,
            primary_op.attrs.get("explicit_padding", None),
            channels=cmd.ps.ifm_shapes[0].depth,
            dtype=cmd.ifm_tensor.dtype,
        )
        top, left, bottom, right = 0, 0, 0, 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def modify_tile_addresses_for_padding(
    tile_box: NpuTileBox, padding_direction: List[int], channels: int, dtype: DataType
) -> NpuTileBox:
    # Addresses are 16-bytes aligned when using the NHCWB16 format, which is required to utilize tiling
    # Calculate the offset to top right, bottom left and bottom right element in the IFM (top left offset is 0)
    """
    Example: 4x4x1 IFM
    | a b c d |  <-- Offset to TR ('d') is (w0-1) = 3
    | e f g h |
    | i j k l |
    | m n o p |  <-- Offset to BL ('m') is (w0*(h0-1)) = 12 and to BR ('p') ((w0*h0)-1) = 15
    """
    h0, h1, w0, addresses = tile_box
    elem_size = 2 if dtype == DataType.int16 else 1
    tr_offset = (w0 - 1) * 16 * elem_size
    bl_offset = w0 * (h0 - 1) * 16 * (round_up(channels, 16) // 16) * elem_size
    br_offset = tr_offset + bl_offset

    # Explicit padding order: (Top, Left, Bottom, Right)
    if padding_direction == (1, 1, 0, 0):
        # Pad top left corner
        """
                   | a a b |
        | a b | -> | a a b |
        | c d |    | c c d |
        """
        addresses = [addresses[0]] * 4
        h0, h1, w0 = 1, 1, 1

    elif padding_direction == (1, 0, 0, 1):
        # Pad top right corner
        """
                   | a b b |
        | a b | -> | a b b |
        | c d |    | c d d |
        """
        addresses = [addresses[0], addresses[0] + tr_offset, addresses[0], addresses[0] + tr_offset]
        h0, h1, w0 = 1, 1, w0

    elif padding_direction == (0, 1, 1, 0):
        # Pad bottom left corner
        """
        | a b |    | a a b |
        | c d | -> | c c d |
                   | c c d |
        """
        addresses = [addresses[0], addresses[0], addresses[0] + bl_offset, addresses[0] + bl_offset]
        h0, h1, w0 = h0, h1, 1

    elif padding_direction == (0, 0, 1, 1):
        # Pad bottom right corner
        """
        | a b |    | a b b |
        | c d | -> | c d d |
                   | c d d |
        """
        addresses = [
            addresses[0],
            addresses[0] + tr_offset,
            addresses[0] + bl_offset,
            addresses[0] + br_offset,
        ]
        # h0, h1, w0 = h0, h1, w0
    else:
        assert 0, "Invalid padding direction for tile padding"

    return NpuTileBox(height_0=h0, height_1=h1, width_0=w0, addresses=[int(addr) for addr in addresses])


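# Worked example of the offsets above for the 4x4x1 IFM in the docstring, assuming int8 data
# (elem_size = 1) and a single 16-channel brick (round_up(1, 16) // 16 = 1):
#   tr_offset = (4 - 1) * 16 * 1     = 48 bytes
#   bl_offset = 4 * (4 - 1) * 16 * 1 = 192 bytes
#   br_offset = 48 + 192             = 240 bytes

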
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[mem_type].value


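# The returned region number is simply the BasePointerIndex value: Permanent_NPU/Permanent_CPU -> 0
# (WeightTensor), Scratch -> 1 (ScratchTensor), and Scratch_fast -> 2 (ScratchFastTensor) when
# spilling is enabled, otherwise 1.

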
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    mem_limits = dict()
    for mem_type in MemType.all():
        mem_limits[get_region(mem_type, arch)] = arch.mem_type_size(mem_type)
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
        return False
    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (
            ps.primary_op.activation is None
            or forced_ofm_quantization is not None
            or (
                ps.primary_op.type.is_avgpool_op()
                and ps.primary_op.activation.op_type.is_relu_op()
                and not ps.primary_op.rescale
            )
        )
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(
    tens: Tensor,
    box: Box,
    arch: ArchitectureFeatures,
    op_shape4D: Shape4D,
    stride_multiplier: Optional[List[int]] = None,
) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    strides = tens.get_strides(op_shape4D)
    assert strides is not None

    if stride_multiplier and stride_multiplier != [1, 1, 1]:
        assert (
            tens.format == TensorFormat.NHWC
        ), "Only default stride multiplier ([1, 1, 1]) supported for NHCWB16 format"
        # Multiply strides for C/H/W (in that order) with corresponding stride factor
        for i, stride_factor in enumerate(stride_multiplier, start=1):
            strides[i] *= stride_factor

    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, strides, op_shape4D
    )

    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    fm.name = tens.name
    return fm


def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales"""
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0

    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = cast(NpuWeightTensor, weight_tensor.src_tensor)

    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor == w_tensor_src:
                # Straight from source tensor
                address = weight_tensor.address + weight_range.offset
            else:
                # Weight buffered tensor
                address = weight_tensor.address + core_offset
                core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
            zero_point = quant.zero_point
            if act.min is not None:
                act.min = scale_f32 * quantise_float32(act.min, scale_f32, zero_point)
            if act.max is not None:
                act.max = scale_f32 * quantise_float32(act.max, scale_f32, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0], op.ofm_stride_multiplier)
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op, npu_op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[op.ifm_resampling_mode]
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        if cmd.weight_tensor.src_tensor:
            npu_op.block_traversal = cmd.weight_tensor.src_tensor.hw_traversal
        else:
            npu_op.block_traversal = cmd.weight_tensor.hw_traversal
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type.is_resize_op():
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and op.original_type.is_resize_op():
        # Force output scale same as the input scale for
        # resizebilinear/nearestneighbor 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    elif op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    elif op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    elif op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    elif op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


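# Note on the Sigmoid/Tanh case above: a Sigmoid or Tanh fused into Add/Mul/Sub forces the OFM
# scale to 1 / 0x3000 (= 1 / 12288, roughly 8.14e-05) while the OFM zero point is left unchanged.

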
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key in cmd.in_tensor.encoded_ranges:
                weight_range = cmd.in_tensor.encoded_ranges[key]
                sz += round_up(weight_range.total_bytes, 16)

                if core == 0:
                    weight_range = cmd.in_tensor.encoded_ranges[key]
                    src_addr = cmd.in_tensor.address + weight_range.offset
        dest_addr = cmd.out_tensor.address
    else:
        src_addr = cmd.in_tensor.address_for_coordinate(cmd.box.start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(cmd.box.start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
        npu_op.name = cmd.out_tensor.name
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
        npu_op.name = cmd.ps.primary_op.name
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(
            npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
        )
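

# Minimal usage sketch (hypothetical driver code, not part of this module): once the scheduler
# has populated sg.high_level_command_stream, the register command stream for each NPU subgraph
# could be produced along these lines:
#
#   for sg in nng.subgraphs:
#       generate_register_command_stream_for_sg(nng, sg, arch, verbose=False)
#       # sg.register_command_stream now holds the generated stream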