# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuOperationType
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .weight_compressor import NpuWeightTensor
from .weight_compressor import WeightKey


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


# inverse of the resampling_mode_map in the register command stream generator
resampling_mode_inv_map = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}

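# Binary elementwise ops need the scalar/broadcast operand in IFM2; the helper below checks whether
# the current operand order already satisfies that. Illustrative example (assumed shapes, not taken
# from the compiler):
#   ifm_ifm2_correct_order(Shape4D([1, 8, 8, 64]), Shape4D([1, 1, 1, 64]))  -> True
#   ifm_ifm2_correct_order(Shape4D([1, 1, 1, 64]), Shape4D([1, 8, 8, 64]))  -> False (caller swaps the operands)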
def ifm_ifm2_correct_order(ifm_shape: Shape4D, ifm2_shape: Shape4D) -> bool:

    if ifm_shape is None:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape is None:
        return True

    for ifm, ifm2 in zip(ifm_shape.as_list(), ifm2_shape.as_list()):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type.is_resize_op():
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation, npu_op: NpuBlockOperation) -> NpuPadding:
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # The ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    ifm_read_offset = primary_op.read_offsets[0]
    ifm_read_shape = primary_op.read_shapes[0]
    if ifm_read_offset is None or len(ifm_read_offset) < 2:
        box_start_coord_min = 0
        box_end_coord_max = cmd.ps.ifm_shapes[0].width
    else:
        box_start_coord_min = ifm_read_offset[-2]
        box_end_coord_max = ifm_read_shape[-2]

    # Indexing from the end since a 1x1 AvgPool might have been added with non 4-dimensional input/output
    # because an activation function needed to be fused.
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > box_start_coord_min:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < box_end_coord_max:
        right = 0

    # If tile padding is selected, modify the tile base addresses and set NpuPadding to zero.
    if primary_op.attrs.get("padding", None) == Padding.TILE:
        assert cmd.ifm_tensor.format == TensorFormat.NHCWB16, "Tensor format NHCWB16 required to perform tile padding"
        assert npu_op.op_type == NpuOperationType.ConvDepthWise, "Tile padding only supported for depthwise convolution"
        assert npu_op.ifm is not None, "Feature map must be initialized to modify the tile addresses"
        npu_op.ifm.tiles = modify_tile_addresses_for_padding(
            npu_op.ifm.tiles,
            primary_op.attrs.get("explicit_padding", None),
            channels=cmd.ps.ifm_shapes[0].depth,
            dtype=cmd.ifm_tensor.dtype,
        )
        top, left, bottom, right = 0, 0, 0, 0

    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def modify_tile_addresses_for_padding(
    tile_box: NpuTileBox, padding_direction: List[int], channels: int, dtype: DataType
) -> NpuTileBox:
    # Addresses are 16-byte aligned when using the NHCWB16 format, which is required to utilize tiling
    # Calculate the offset to the top right, bottom left and bottom right elements in the IFM (top left offset is 0)
    """
    Example: 4x4x1 IFM
    | a b c d |  <-- Offset to TR ('d') is (w0-1) = 3
    | e f g h |
    | i j k l |
    | m n o p |  <-- Offset to BL ('m') is (w0*(h0-1)) = 12 and to BR ('p') ((w0*h0)-1) = 15
    """
    h0, h1, w0, addresses = tile_box
    elem_size = 2 if dtype == DataType.int16 else 1
    tr_offset = (w0 - 1) * 16 * elem_size
    bl_offset = w0 * (h0 - 1) * 16 * (round_up(channels, 16) // 16) * elem_size
    br_offset = tr_offset + bl_offset
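    # Worked example (illustrative): for the 4x4x1 int8 IFM in the docstring above, w0 = 4, h0 = 4,
    # channels = 1 and elem_size = 1, so tr_offset = 3 * 16 = 48 bytes, bl_offset = 4 * 3 * 16 = 192 bytes
    # and br_offset = 240 bytes, i.e. the element offsets 3, 12 and 15 scaled by the 16-byte NHCWB16 brick.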

    # Explicit padding order: (Top, Left, Bottom, Right)
    if padding_direction == (1, 1, 0, 0):
        # Pad top left corner
        """
                   | a a b |
        | a b | -> | a a b |
        | c d |    | c c d |
        """
        addresses = [addresses[0]] * 4
        h0, h1, w0 = 1, 1, 1

    elif padding_direction == (1, 0, 0, 1):
        # Pad top right corner
        """
                   | a b b |
        | a b | -> | a b b |
        | c d |    | c d d |
        """
        addresses = [addresses[0], addresses[0] + tr_offset, addresses[0], addresses[0] + tr_offset]
        h0, h1, w0 = 1, 1, w0

    elif padding_direction == (0, 1, 1, 0):
        # Pad bottom left corner
        """
        | a b |    | a a b |
        | c d | -> | c c d |
                   | c c d |
        """
        addresses = [addresses[0], addresses[0], addresses[0] + bl_offset, addresses[0] + bl_offset]
        h0, h1, w0 = h0, h1, 1

    elif padding_direction == (0, 0, 1, 1):
        # Pad bottom right corner
        """
        | a b |    | a b b |
        | c d | -> | c d d |
                   | c d d |
        """
        addresses = [
            addresses[0],
            addresses[0] + tr_offset,
            addresses[0] + bl_offset,
            addresses[0] + br_offset,
        ]
        # h0, h1, w0 = h0, h1, w0
    else:
        assert 0, "Invalid padding direction for tile padding"

    return NpuTileBox(height_0=h0, height_1=h1, width_0=w0, addresses=[int(addr) for addr in addresses])


def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[mem_type].value


def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    mem_limits = dict()
    for mem_type in MemType.all():
        mem_limits[get_region(mem_type, arch)] = arch.mem_type_size(mem_type)
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    # Force zero point to 0 for ResizeBilinear when converting to a DepthwiseConv since the reference kernel
    # will ignore the zero point.
    if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
        return False
    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (
            ps.primary_op.activation is None
            or forced_ofm_quantization is not None
            or (ps.primary_op.type.is_avgpool_op() and ps.primary_op.activation.op_type.is_relu_op())
        )
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
340 ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
341 if ofm_quant is None:
342 return None
343 if use_zero_point_0(ps, tens, False):
344 zero_point = 0
345 else:
346 zero_point = int(ofm_quant.zero_point)
347 return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
348
349
Rickard Bolin17e53b52022-09-06 16:09:01 +0000350def create_feature_map(
351 tens: Tensor,
352 box: Box,
353 arch: ArchitectureFeatures,
354 op_shape4D: Shape4D,
Rickard Bolinfea15162022-07-04 16:19:16 +0000355 tile_base_offsets: List[int],
Rickard Bolin17e53b52022-09-06 16:09:01 +0000356 stride_multiplier: Optional[List[int]] = None,
357) -> NpuFeatureMap:
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100358 """Creates feature map with common fields populated"""
359 fm = NpuFeatureMap()
Louis Verhaard024c3552021-03-17 14:26:34 +0100360 fm.region = get_region(tens.mem_type, arch)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100361 fm.data_type = dtype_map[tens.dtype]
362 if tens.format == TensorFormat.NHWC:
363 fm.layout = NpuLayout.NHWC
364 elif tens.format == TensorFormat.NHCWB16:
365 fm.layout = NpuLayout.NHCWB16
366 else:
367 assert 0, "Incorrect tensor format"
Rickard Bolin17e53b52022-09-06 16:09:01 +0000368
369 strides = tens.get_strides(op_shape4D)
370 assert strides is not None
371
372 if stride_multiplier and stride_multiplier != [1, 1, 1]:
373 assert (
374 tens.format == TensorFormat.NHWC
375 ), "Only default stride multiplier ([1, 1, 1]) supported for NHCWB16 format"
376 # Multiply strides for C/H/W (in that order) with corresponding stride factor
377 for i, stride_factor in enumerate(stride_multiplier, start=1):
378 strides[i] *= stride_factor
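        # For example (illustrative): a stride_multiplier of [1, 2, 1] leaves the depth and width
        # strides unchanged but doubles the height stride, so consecutive rows are written twice as
        # far apart in memory.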

    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, strides, op_shape4D
    )

    for idx, offset in enumerate(tile_base_offsets):
        addresses[idx] += offset
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    fm.name = tens.name
    return fm


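# Weight and scale address ranges are emitted per NPU core: with arch.ncores == 2 (illustrative),
# each stripe yields up to two weight ranges and two bias/scale ranges, one per core, each rounded
# up to a 16-byte boundary.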
def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales"""
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0

    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = cast(NpuWeightTensor, weight_tensor.src_tensor)

    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor == w_tensor_src:
                # Straight from source tensor
                address = weight_tensor.address + weight_range.offset
            else:
                # Weight buffered tensor
                address = weight_tensor.address + core_offset
                core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.explicit_scaling:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
            zero_point = quant.zero_point
            if act.min is not None:
                act.min = scale_f32 * quantise_float32(act.min, scale_f32, zero_point)
            if act.max is not None:
                act.max = scale_f32 * quantise_float32(act.max, scale_f32, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(
        cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0], op.tile_base_offsets_ofm, op.ofm_stride_multiplier
    )
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op, npu_op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[op.ifm_resampling_mode]
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        if cmd.weight_tensor.src_tensor:
            npu_op.block_traversal = cmd.weight_tensor.src_tensor.hw_traversal
        else:
            npu_op.block_traversal = cmd.weight_tensor.hw_traversal
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type.is_resize_op():
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    if op.explicit_scaling:
        # Note: rescale is reused for explicit scaling so that this is not exposed in the external API
        npu_op.rescale = op.explicit_scaling
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = None if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0]
        ifm2_shape = None if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1]
        if cmd.reversed_operands:
            assert ifm_ifm2_correct_order(ifm_shape, ifm2_shape)
            npu_op.reversed_operands = True
        elif not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(
            cmd.ifm2_tensor,
            cmd.ifm2_box,
            arch,
            ps.ifm_shapes[1],
            op.tile_base_offsets_ifm[1],
        )
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.explicit_scaling is not None:
        assert not op.explicit_scaling.per_channel
        assert op.type in (Op.Add, Op.Mul, Op.Sub)
        npu_op.rescale = (op.explicit_scaling.multiplier[0], op.explicit_scaling.shift[0])
    elif op.type == Op.Add and op.original_type.is_resize_op():
        # Force output scale same as the input scale for
        # resizebilinear/nearestneighbor 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    elif op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    elif op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    elif op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
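            # 1 / 0x3000 = 1 / 12288 ~= 8.14e-5, the fixed output scale used when a Sigmoid/Tanh
            # activation is fused with the elementwise op (illustrative arithmetic only)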
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key in cmd.in_tensor.encoded_ranges:
                weight_range = cmd.in_tensor.encoded_ranges[key]
                sz += round_up(weight_range.total_bytes, 16)

                if core == 0:
                    weight_range = cmd.in_tensor.encoded_ranges[key]
                    src_addr = cmd.in_tensor.address + weight_range.offset
                    dest_addr = cmd.out_tensor.address
    else:
        src_addr = cmd.in_tensor.address_for_coordinate(cmd.box.start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(cmd.box.start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
        npu_op.name = cmd.out_tensor.name
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
        npu_op.name = cmd.ps.primary_op.name
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(
            npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
        )