blob: 7923e3717446831e769e3c37a05fef4b2d2c868c [file] [log] [blame]
Rickard Bolinfea15162022-07-04 16:19:16 +00001# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
19from enum import IntEnum
Jonas Ohlsson845e2322022-03-01 12:39:55 +010020from typing import cast
Louis Verhaard024c3552021-03-17 14:26:34 +010021from typing import Dict
Louis Verhaarde8a5a782020-11-02 18:04:27 +010022from typing import List
23from typing import Optional
Jonas Ohlsson845e2322022-03-01 12:39:55 +010024from typing import Tuple
Louis Verhaarde8a5a782020-11-02 18:04:27 +010025
26from .api import NpuActivation
27from .api import NpuActivationOp
28from .api import NpuAddressRange
29from .api import NpuBlockOperation
30from .api import NpuBlockTraversal
31from .api import NpuConv2DOperation
32from .api import NpuConvDepthWiseOperation
33from .api import NpuDataType
34from .api import NpuDmaOperation
35from .api import NpuElementWiseOp
36from .api import NpuElementWiseOperation
37from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010038from .api import NpuLayout
39from .api import NpuOperation
Rickard Bolin9ae34552022-06-09 13:07:17 +000040from .api import NpuOperationType
Louis Verhaarde8a5a782020-11-02 18:04:27 +010041from .api import NpuPadding
42from .api import NpuPoolingOp
43from .api import NpuPoolingOperation
44from .api import NpuQuantization
45from .api import NpuResamplingMode
46from .api import NpuRoundingMode
47from .api import NpuShape3D
48from .api import NpuTileBox
49from .architecture_features import ArchitectureFeatures
50from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010051from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000052from .errors import UnsupportedFeatureError
Tim Hall3c5cfe92022-03-16 16:31:57 +000053from .ethos_u55_regs.ethos_u55_regs import resampling_mode
Louis Verhaarde8a5a782020-11-02 18:04:27 +010054from .high_level_command_stream import Box
55from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010056from .high_level_command_stream import DMA
57from .high_level_command_stream import NpuStripe
Fredrik Svedberg838df0a2021-09-17 16:29:22 +020058from .numeric_util import quantise_float32
Tim Halld8339a72021-05-27 18:49:40 +010059from .numeric_util import round_up
Louis Verhaarde8a5a782020-11-02 18:04:27 +010060from .operation import NpuBlockType
61from .operation import Op
62from .operation import Operation
Rickard Bolin9ae34552022-06-09 13:07:17 +000063from .operation import Padding
Louis Verhaard1e170182020-11-26 11:42:04 +010064from .register_command_stream_generator import generate_command_stream
65from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010066from .register_command_stream_util import to_npu_kernel
67from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000068from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069from .tensor import MemType
70from .tensor import Tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010071from .tensor import TensorFormat
72from .tensor import TensorPurpose
Jonas Ohlsson845e2322022-03-01 12:39:55 +010073from .weight_compressor import NpuWeightTensor
Tim Halld8339a72021-05-27 18:49:40 +010074from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010075
76
class BasePointerIndex(IntEnum):
    """Base pointer (region) indices through which tensors are addressed."""

    # Weight tensor
    WeightTensor = 0
    # Scratch_tensor in the TensorArena
    ScratchTensor = 1
    # Scratch_fast_tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010081
82
# Translation from the compiler's DataType to the NpuDataType of the .api module
dtype_map: Dict[DataType, NpuDataType] = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
90
91
# Elementwise op type -> elementwise_mode value used by NPU_OP_ELEMENTWISE.
# The Rescale* variants share the plain ADD/MUL modes.
elementwise_op_map: Dict[Op, NpuElementWiseOp] = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
107
108
# Register-level resampling_mode -> NpuResamplingMode of the .api module
# (the inverse of resampling_mode_map in the register command stream generator)
resampling_mode_inv_map: Dict[resampling_mode, NpuResamplingMode] = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}
115
116
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """
    Tells whether the two elementwise operands are already in the order the NPU needs.

    A scalar operand (empty shape) and a broadcasted operand (a dimension of 1 where the
    other operand's dimension differs) must be placed in IFM2.

    Returns:
        True if the operands can stay as they are, False if they must be swapped.
    """
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    # Broadcasted FM needs to be in IFM2
    ifm_is_broadcast = any(dim == 1 and dim != dim2 for dim, dim2 in zip(ifm_shape, ifm2_shape))
    return not ifm_is_broadcast
129
130
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """
    Specifies type of rounding to be used.

    NATURAL rounding is selected for:
    - resize ops,
    - int16 ConvolutionMxN / ConvolutionDepthWise ops,
    - 1x1 average pooling writing into a concat slice when no Quantize op is fused.
    Everything else uses TFL rounding. An explicit op.rounding_mode always takes precedence.
    """
    op_type = op.type
    if op_type.is_resize_op():
        use_natural = True
    elif (
        op_type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        use_natural = True
    elif (
        not fused_quantize
        and op_type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        use_natural = True
    else:
        use_natural = False

    derived_mode = NpuRoundingMode.NATURAL if use_natural else NpuRoundingMode.TFL
    # An explicitly requested rounding mode on the op overrides the derived one
    return derived_mode if op.rounding_mode is None else op.rounding_mode
151
152
def create_padding(cmd: NpuStripe, primary_op: Operation, npu_op: NpuBlockOperation) -> NpuPadding:
    """
    Creates the NpuPadding for one stripe of primary_op.

    Starts from primary_op.attrs["explicit_padding"] (top, left, bottom, right), then:
    - uses the stripe's own pad_top/pad_bottom when the IFM is streamed in several horizontal stripes,
    - drops left/right padding when the stripe's IFM box does not reach the left/right edge of the
      readable IFM area (which may be restricted by a split slice read),
    - for Padding.TILE, folds the padding into the IFM tile addresses (mutating npu_op.ifm.tiles)
      and returns zero padding.
    VectorProduct ops never get padding.

    Args:
        cmd: the stripe command being converted
        primary_op: cmd's primary operation
        npu_op: the operation under construction; its ifm must already be set for tile padding

    Returns:
        the padding to program for this stripe
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # the ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    ifm_read_offset = primary_op.read_offsets[0]
    ifm_read_shape = primary_op.read_shapes[0]
    if ifm_read_offset is None or len(ifm_read_offset) < 2:
        box_start_coord_min = 0
        box_end_coord_max = cmd.ps.ifm_shapes[0].width
    else:
        box_start_coord_min = ifm_read_offset[-2]
        # NOTE(review): this uses read_shape[-2] on its own, not read_offset[-2] + read_shape[-2];
        # confirm that read_shapes holds an end coordinate here rather than a size.
        box_end_coord_max = ifm_read_shape[-2]

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because of activation function needed to be fused.
    # A box that starts right of the readable area's left edge gets no left padding...
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > box_start_coord_min:
        left = 0
    # ...and a box that ends before its right edge gets no right padding.
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < box_end_coord_max:
        right = 0

    # If tile padding is selected, modify the tile base addresses and set NpuPadding to zero.
    if primary_op.attrs.get("padding", None) == Padding.TILE:
        assert cmd.ifm_tensor.format == TensorFormat.NHCWB16, "Tensor format NHCWB16 required to perform tile padding"
        assert npu_op.op_type == NpuOperationType.ConvDepthWise, "Tile padding only supported for depthwise convolution"
        assert npu_op.ifm is not None, "Feature map must be initialized to modify the tile addresses"
        # The direction is the op's original explicit padding, not the stripe-adjusted values above
        npu_op.ifm.tiles = modify_tile_addresses_for_padding(
            npu_op.ifm.tiles,
            primary_op.attrs.get("explicit_padding", None),
            channels=cmd.ps.ifm_shapes[0].depth,
            dtype=cmd.ifm_tensor.dtype,
        )
        top, left, bottom, right = 0, 0, 0, 0

    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100194
195
def modify_tile_addresses_for_padding(
    tile_box: NpuTileBox, padding_direction: List[int], channels: int, dtype: DataType
) -> NpuTileBox:
    """
    Rewrites an IFM tile box so that the tile mechanism pads one corner of the IFM.

    Tiles are re-pointed at addresses derived from tile 0 so that, as in the diagrams below,
    the edge row/column next to the padded corner is repeated. Addresses are 16-byte aligned
    in the NHCWB16 format, which is required to use tiling.

    Offsets from the top left element (offset 0) to the other corners, in elements:

    Example: 4x4x1 IFM
    | a b c d | <-- Offset to TR ('d') is (w0-1) = 3
    | e f g h |
    | i j k l |
    | m n o p | <-- Offset to BL ('m') is (w0*(h0-1)) = 12 and to BR ('p') ((w0*h0)-1) = 15

    Args:
        tile_box: current tile box; only its first address and its h0/h1/w0 are used
        padding_direction: explicit padding (top, left, bottom, right); must be one of
            (1, 1, 0, 0), (1, 0, 0, 1), (0, 1, 1, 0) or (0, 0, 1, 1). A list or a tuple is accepted.
        channels: IFM depth, rounded up to whole 16-channel bricks for the row offset
        dtype: IFM data type; int16 uses 2-byte elements, all other types 1 byte

    Returns:
        a new NpuTileBox with the adjusted tile sizes and addresses

    Raises:
        AssertionError: if padding_direction is not one of the supported corners
    """
    h0, h1, w0, addresses = tile_box
    elem_size = 2 if dtype == DataType.int16 else 1
    tr_offset = (w0 - 1) * 16 * elem_size
    bl_offset = w0 * (h0 - 1) * 16 * (round_up(channels, 16) // 16) * elem_size
    br_offset = tr_offset + bl_offset

    # The annotation allows a list, but a list never compares equal to the tuple patterns
    # below, so normalise before matching.
    direction = tuple(padding_direction) if padding_direction is not None else None

    # Explicit padding order: (Top, Left, Bottom, Right)
    if direction == (1, 1, 0, 0):
        # Pad top left corner
        """
                  | a a b |
        | a b | -> | a a b |
        | c d |    | c c d |
        """
        addresses = [addresses[0]] * 4
        h0, h1, w0 = 1, 1, 1

    elif direction == (1, 0, 0, 1):
        # Pad top right corner
        """
                  | a b b |
        | a b | -> | a b b |
        | c d |    | c d d |
        """
        addresses = [addresses[0], addresses[0] + tr_offset, addresses[0], addresses[0] + tr_offset]
        h0, h1 = 1, 1

    elif direction == (0, 1, 1, 0):
        # Pad bottom left corner
        """
        | a b |    | a a b |
        | c d | -> | c c d |
                  | c c d |
        """
        addresses = [addresses[0], addresses[0], addresses[0] + bl_offset, addresses[0] + bl_offset]
        w0 = 1

    elif direction == (0, 0, 1, 1):
        # Pad bottom right corner
        """
        | a b |    | a b b |
        | c d | -> | c d d |
                  | c d d |
        """
        addresses = [
            addresses[0],
            addresses[0] + tr_offset,
            addresses[0] + bl_offset,
            addresses[0] + br_offset,
        ]
        # tile sizes are unchanged
    else:
        assert 0, "Invalid padding direction for tile padding"

    return NpuTileBox(height_0=h0, height_1=h1, width_0=w0, addresses=[int(addr) for addr in addresses])
263
264
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """
    Returns the base pointer index (region) used to address tensors of the given memory type.

    Permanent memory goes through the weight region and Scratch through the scratch region.
    Scratch_fast gets its own region only when spilling is enabled; otherwise it shares the
    scratch region.
    """
    scratch_fast_index = (
        BasePointerIndex.ScratchFastTensor if arch.is_spilling_enabled() else BasePointerIndex.ScratchTensor
    )
    index_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return index_by_mem_type[mem_type].value
278
279
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """
    Returns map region -> max size of the region in bytes.

    When several memory types share a region, the size of the last one in MemType.all() wins.
    The mem2mem region is limited by the SHRAM size.
    """
    limits = {get_region(mem, arch): arch.mem_type_size(mem) for mem in MemType.all()}
    limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100287
288
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """
    Returns the IFM depth for a stripe.

    ConvolutionMxN, VectorProduct and ReduceSum take it from the IFM box; every other
    block type takes it from the OFM box.
    """
    depth_from_ifm = npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100295
296
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """
    Checks if quantization should use 0 as zero point.

    Decides whether the zero point of tens must be forced to 0 instead of being taken from
    its quantization parameters.

    Args:
        ps: the pass being converted; its primary_op and fused ops are inspected
        tens: the IFM/IFM2/OFM tensor whose quantization is being built
        is_ifm_tensor: True for IFM/IFM2, False for the OFM

    Returns:
        True if a zero point of 0 should be used
    """
    # int32 inputs always use a zero point of 0
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    # Force zero point to 0 for ResizeBilinear when converting to a DepthwiseConv since the reference kernel
    # will ignore the zero point.
    if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
        return True
    # From here on, only AvgPool, CLZ, SHL and resize ops can have their zero point forced to 0
    if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
        return False
    # AvgPool with explicit scaling keeps the tensor's zero point
    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    # Zero point is 0 when there is no fused activation, the output quantization is forced, or
    # (for an avgpool without rescale) the fused activation is a ReLU variant. In all three cases
    # this applies only when the op does not write into a concat slice and no Quantize op is fused.
    use_0 = (
        (
            ps.primary_op.activation is None
            or forced_ofm_quantization is not None
            or (
                ps.primary_op.type.is_avgpool_op()
                and ps.primary_op.activation.op_type.is_relu_op()
                and not ps.primary_op.rescale
            )
        )
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0
325
326
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """
    Gets quantization for IFM/IFM2.

    A forced input quantization on the primary op takes precedence over the tensor's own
    quantization. Returns None when neither exists. The zero point is forced to 0 when
    use_zero_point_0 says so.
    """
    forced_quant = ps.primary_op.forced_input_quantization
    quant = tens.quantization if forced_quant is None else forced_quant
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100338
339
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """
    Gets quantization for OFM.

    A forced output quantization on the primary op (used in LUTs) takes precedence over the
    tensor's own quantization. Returns None when neither exists. The zero point is forced to 0
    when use_zero_point_0 says so.
    """
    forced_quant = ps.primary_op.forced_output_quantization
    quant = tens.quantization if forced_quant is None else forced_quant
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
353
354
def create_feature_map(
    tens: Tensor,
    box: Box,
    arch: ArchitectureFeatures,
    op_shape4D: Shape4D,
    tile_base_offsets: List[int],
    stride_multiplier: Optional[List[int]] = None,
) -> NpuFeatureMap:
    """
    Creates feature map with common fields populated.

    Fills in region, data type, layout, strides, tile box and name. Shape and quantization
    are left for the caller.

    Args:
        tens: the tensor backing the feature map (NHWC or NHCWB16 format)
        box: the part of the tensor covered by this stripe
        arch: architecture, used to resolve the region
        op_shape4D: shape used to compute the strides and rolling-buffer tile addresses
        tile_base_offsets: values added to the computed tile addresses, one per tile in order
        stride_multiplier: optional factors for the C, H and W strides (in that order);
            anything other than [1, 1, 1] is only allowed for NHWC tensors
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]

    tensor_format = tens.format
    if tensor_format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tensor_format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    strides = tens.get_strides(op_shape4D)
    assert strides is not None

    has_custom_multiplier = stride_multiplier and stride_multiplier != [1, 1, 1]
    if has_custom_multiplier:
        assert (
            tens.format == TensorFormat.NHWC
        ), "Only default stride multiplier ([1, 1, 1]) supported for NHCWB16 format"
        # strides[1], strides[2], strides[3] are the C, H, W strides
        for axis, factor in enumerate(stride_multiplier, start=1):
            strides[axis] *= factor

    height_0, height_1, width_0, tile_addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, strides, op_shape4D
    )
    for tile_idx, base_offset in enumerate(tile_base_offsets):
        tile_addresses[tile_idx] += base_offset

    fm.tiles = NpuTileBox(
        height_0=height_0,
        height_1=height_1,
        width_0=width_0,
        addresses=[int(address) for address in tile_addresses],
    )
    fm.strides = NpuShape3D(depth=int(strides[1]), height=int(strides[2]), width=int(strides[3]))
    fm.name = tens.name
    return fm
397
398
def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """
    Returns address ranges for weights and scales.

    For every NPU core that has an encoded range for the depth slice starting at
    weight_box.start_coord[-1], one weight range and one scale (bias) range are produced.
    All sizes are rounded up to a multiple of 16 bytes.

    Args:
        weight_tensor: the weight tensor read by the op; either the encoded source tensor itself
            or a buffer whose src_tensor is the encoded source
        weight_box: box selecting the depth slice of the weights
        scale_tensor: optional standalone scale tensor; when falsy the scales are taken from
            the encoded weight stream
        arch: architecture, for the number of cores and region lookup

    Returns:
        (weights, biases): lists with one NpuAddressRange per core that has encoded data
    """
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0

    # Encoded ranges are always looked up on the source tensor
    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = cast(NpuWeightTensor, weight_tensor.src_tensor)

    # Running offset into a weight buffer: per-core ranges are packed back to back, 16-byte aligned
    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor == w_tensor_src:
                # Straight from source tensor
                address = weight_tensor.address + weight_range.offset
            else:
                # Weight buffered tensor
                address = weight_tensor.address + core_offset
                core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor (weight_offset bytes into this core's range)
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                # Combined stream: scales are read from the start of this core's range
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100444
445
def create_npu_activation(op: Operation) -> NpuActivation:
    """
    Creates fused activation function.

    Tanh, Sigmoid and LUT map to their own activation ops. ReLU variants, and the absence of an
    activation, map to NONE_OR_RELU, carrying the activation's min/max clamp values. For an
    avgpool op without rescale whose OFM has a non-zero zero point, the ReLU clamp values are
    re-expressed through quantise_float32 using the OFM quantization.

    Raises:
        UnsupportedFeatureError: for any other fused activation type
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    dedicated_ops = {
        Op.Tanh: NpuActivationOp.TANH,
        Op.Sigmoid: NpuActivationOp.SIGMOID,
        Op.LUT: NpuActivationOp.TABLE_LOOKUP,
    }
    if faf in dedicated_ops:
        act_op = dedicated_ops[faf]
    elif faf.is_relu_op():
        act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = activation.min
    act.max = activation.max

    clamp_uses_ofm_quant = act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale
    if clamp_uses_ofm_quant:
        ofm_quant = op.ofm.quantization
        if ofm_quant and ofm_quant.zero_point:  # Zero point is not 0
            scale = ofm_quant.scale_f32 if ofm_quant.scale_f32 is not None else 1
            zp = ofm_quant.zero_point
            if act.min is not None:
                act.min = scale * quantise_float32(act.min, scale, zp)
            if act.max is not None:
                act.max = scale * quantise_float32(act.max, scale, zp)

    act.lookup_table_index = activation.lut_index
    return act
475
476
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """
    Sets common fields of the given operation.

    Fills in the IFM and OFM feature maps (with shapes and quantization), weights/biases when
    the stripe has a weight tensor, activation, fused-quantize flag, rounding mode and block
    config. For non-elementwise ops it also sets padding, kernel and IFM upscaling.

    Returns:
        npu_op, updated in place
    """
    ps = cmd.ps
    primary = ps.primary_op

    # IFM: stripe height and op-dependent depth come from the boxes; width is the full IFM width
    ifm_block_height = cmd.ifm_box.get_block().height
    full_ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_block_depth = get_ifm_depth(primary.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    ifm_fm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], primary.tile_base_offsets_ifm[0])
    npu_op.ifm = ifm_fm
    ifm_fm.shape = NpuShape3D(height=ifm_block_height, width=full_ifm_width, depth=ifm_block_depth)
    ifm_fm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: the whole output box of this stripe
    ofm_block = cmd.ofm_box.get_block()
    ofm_fm = create_feature_map(
        cmd.ofm_tensor,
        cmd.ofm_box,
        arch,
        ps.ofm_shapes[0],
        primary.tile_base_offsets_ofm,
        primary.ofm_stride_multiplier,
    )
    npu_op.ofm = ofm_fm
    ofm_fm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    ofm_fm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(primary)

    has_fused_quantize = any(fused_op.type == Op.Quantize for fused_op in ps.ops)
    npu_op.fused_quantize = has_fused_quantize
    npu_op.rounding_mode = get_rounding_mode(primary, has_fused_quantize)
    npu_op.block_config = NpuShape3D(
        height=ps.block_config[0],
        width=ps.block_config[1],
        depth=ps.block_config[3],
    )

    if primary.type.is_elementwise_op():
        return npu_op

    # Padding must be created after the IFM feature map: tile padding rewrites npu_op.ifm.tiles
    npu_op.padding = create_padding(cmd, primary, npu_op)
    npu_op.kernel = to_npu_kernel(primary.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[primary.ifm_resampling_mode]
    return npu_op
509
510
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """
    Converts the command to NpuConv2DOperation.

    VectorProduct ops always use depth-first block traversal. Other ops use the hw_traversal
    stored on the weight tensor, or on its source tensor when the weights are buffered.
    """
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)

    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
        return npu_op

    traversal_owner = cmd.weight_tensor.src_tensor or cmd.weight_tensor
    npu_op.block_traversal = traversal_owner.hw_traversal
    return npu_op
523
524
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation; only the common fields apply."""
    # set_common_op_fields fills in and returns the operation it was given
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
530
531
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """
    Converts the command to NpuPoolingOperation.

    Max pooling ops map to MAX, average pooling and resize ops to AVERAGE, and ReduceSum to
    REDUCE_SUM. The op's rescale is copied over. Explicit scaling, when present, is passed in
    the same rescale field.
    """
    op = cmd.ps.primary_op
    op_type = op.type

    if op_type.is_maxpool_op():
        pooling = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type.is_resize_op():
        # Resize ops are mapped to average pooling
        pooling = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pooling = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"

    npu_op = NpuPoolingOperation(pooling)
    set_common_op_fields(npu_op, cmd, arch)

    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op
553
554
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """
    Converts the command to NpuElementWiseOperation.

    For binary ops the operands are swapped when needed so that a scalar or broadcasted
    operand ends up in IFM2. The swap mutates cmd (tensors and boxes) and ps.ifm_shapes and
    sets reversed_operands. The OFM scale is then overridden for a few op types (see below).

    Args:
        cmd: the stripe command being converted (may be modified)
        arch: architecture features

    Returns:
        the populated NpuElementWiseOperation
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        # Scalar tensors are represented by an empty shape
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            # NOTE(review): ps.ifm_shapes belongs to the pass and is swapped in place; if several stripes of the
            # same pass are converted, later stripes see already-swapped shapes. Confirm this cannot happen.
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        # NOTE(review): op.tile_base_offsets_ifm is not swapped together with the operands above;
        # confirm that the offsets are always zero when reversed_operands is set.
        npu_op.ifm2 = create_feature_map(
            cmd.ifm2_tensor,
            cmd.ifm2_box,
            arch,
            ps.ifm_shapes[1],
            op.tile_base_offsets_ifm[1],
        )
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            # As for the IFM, width is the full IFM2 width rather than the box width
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and op.original_type.is_resize_op():
        # Force output scale same as the input scale for
        # resizebilinear/nearestneighbor 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    elif op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    elif op.type == Op.LeakyRelu:
        # The alpha attribute is passed as the output scale
        output_scale = op.attrs["alpha"]
    elif op.type in (Op.RescaleAdd, Op.RescaleMul):
        # Rescale ops carry their own rescale instead of an output scale override
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    elif op.type in (Op.Add, Op.Mul, Op.Sub):
        # Fixed output scale when a Sigmoid/Tanh activation is fused
        # NOTE(review): the origin of the 0x3000 constant is not visible in this file
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        # Only the scale is overridden; the OFM zero point is kept
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
608
609
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """
    Converts the command to NpuDmaOperation.

    For weight tensors the transfer starts at core 0's encoded range for the depth slice at
    cmd.box.start_coord[-1]. It covers the encoded ranges of every core for that slice, each
    rounded up to 16 bytes, and is written to the start of the destination tensor. For other
    tensors the transfer covers the address span of cmd.box in both tensors.

    LUT destinations are addressed through the mem2mem region; all other tensors through the
    region of their memory type.

    Raises:
        AssertionError: if a weight DMA has no encoded range for core 0
    """
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        encoded_ranges = cmd.in_tensor.encoded_ranges
        depth_offset = cmd.box.start_coord[-1]

        # The per-core ranges are copied as one block starting at core 0's range; fail clearly
        # instead of with an unbound variable when that range is missing.
        first_key = WeightKey(0, depth_offset)
        assert first_key in encoded_ranges, f"No encoded weight range for core 0 at depth offset {depth_offset}"
        src_addr = cmd.in_tensor.address + encoded_ranges[first_key].offset
        dest_addr = cmd.out_tensor.address

        # Get weight range per core; each is padded to 16 bytes
        core_keys = (WeightKey(core, depth_offset) for core in range(arch.ncores))
        sz = sum(round_up(encoded_ranges[key].total_bytes, 16) for key in core_keys if key in encoded_ranges)
    else:
        src_addr = cmd.in_tensor.address_for_coordinate(cmd.box.start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(cmd.box.start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
638
639
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation.

    Conversion rules:
      - DMA commands are converted by create_dma_op, and the result is named
        after the output tensor.
      - NpuStripe commands are dispatched on the NPU block type of the pass's
        primary op, and the result is named after that primary op.

    Raises:
        AssertionError: if an NpuStripe has an NPU block type with no converter.
        TypeError: if cmd is neither a DMA nor an NpuStripe command.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
        npu_op.name = cmd.out_tensor.name
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`. Otherwise npu_op would be left unbound.
            raise AssertionError(f"Unknown command type {npu_block_type}")
        npu_op.name = cmd.ps.primary_op.name
    else:
        # Previously this case surfaced as an UnboundLocalError on `return`.
        raise TypeError(f"Unsupported high level command type {type(cmd).__name__}")
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100660
661
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream

    Steps:
      1. Convert each high level command of sg into an NpuOperation. NpuStripe
         commands whose pass has NpuBlockType.Default are skipped with a
         printed warning.
      2. If the high level command stream is non-empty, register the subgraph
         with the DebugDatabase and store the stream id in
         sg.generated_stream_id.
      3. Build sg.register_command_stream with generate_command_stream.

    If sg.high_level_command_stream is empty, sg.register_command_stream is not
    assigned. The nng argument is not used here.
    """
    # Step 1: lower every supported high level command to an NpuOperation,
    # remembering which command produced each op.
    ops: List[NpuOperation] = []
    source_cmd_of = {}
    for hl_cmd in sg.high_level_command_stream:
        unsupported = isinstance(hl_cmd, NpuStripe) and hl_cmd.ps.npu_block_type == NpuBlockType.Default
        if unsupported:
            print("Warning: Skipping register command stream generation for", hl_cmd.ps)
            continue
        op = convert_command_to_npu_op(hl_cmd, arch)
        ops.append(op)
        source_cmd_of[op] = hl_cmd

    region_limits = get_mem_limits_for_regions(arch)

    # Nothing to emit for an empty high level command stream.
    if len(sg.high_level_command_stream) == 0:
        return

    # Step 2: register the stream with the debug database.
    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Records the primary op of non-DMA operations in the debug database."""
        if isinstance(npu_op, NpuDmaOperation):
            return
        DebugDatabase.add_command(stream_id, offset, source_cmd_of[npu_op].ps.primary_op)

    # Step 3: generate the register command stream.
    sg.register_command_stream = generate_command_stream(
        ops, arch, verbose, region_limits, add_to_debug_db, source_cmd_of
    )