# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuOperationType
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .weight_compressor import NpuWeightTensor
from .weight_compressor import WeightKey


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


# inverse of the resampling_mode_map in the register command stream generator
resampling_mode_inv_map = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
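    """
    Returns True if the given shapes can be used as IFM and IFM2 in that order.

    A scalar or broadcast feature map has to go to IFM2; e.g. (illustrative shapes)
    ifm_shape=[1, 8, 8, 1] with ifm2_shape=[1, 8, 8, 16] returns False, in which case
    the caller swaps the operands and marks the NPU operation with reversed_operands.
    """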
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
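    # NATURAL rounding is used for resize ops, for int16 convolutions (regular and depthwise) and for
    # 1x1 average pools that write into a concat slice without a fused quantize; an explicit
    # op.rounding_mode always takes precedence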
    rounding_mode = NpuRoundingMode.TFL
    if op.type.is_resize_op():
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation, npu_op: NpuBlockOperation) -> NpuPadding:
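    """Creates the explicit NpuPadding for the stripe; edges that are interior to the full feature map
    get zero padding, and for Padding.TILE the IFM tile base addresses are modified instead of padding"""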
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # the ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    ifm_read_offset = primary_op.read_offsets[0]
    ifm_read_shape = primary_op.read_shapes[0]
    if ifm_read_offset is None or len(ifm_read_offset) < 2:
        box_start_coord_min = 0
        box_end_coord_max = cmd.ps.ifm_shapes[0].width
    else:
        box_start_coord_min = ifm_read_offset[-2]
        box_end_coord_max = ifm_read_shape[-2]

    # Indexing from the end, since a 1x1 AvgPool might have been added with non-4-dimensional input/output
    # because an activation function needed to be fused.
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > box_start_coord_min:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < box_end_coord_max:
        right = 0

    # If tile padding is selected, modify the tile base addresses and set NpuPadding to zero.
    if primary_op.attrs.get("padding", None) == Padding.TILE:
        assert cmd.ifm_tensor.format == TensorFormat.NHCWB16, "Tensor format NHCWB16 required to perform tile padding"
        assert npu_op.op_type == NpuOperationType.ConvDepthWise, "Tile padding only supported for depthwise convolution"
        assert npu_op.ifm is not None, "Feature map must be initialized to modify the tile addresses"
        npu_op.ifm.tiles = modify_tile_addresses_for_padding(
            npu_op.ifm.tiles,
            primary_op.attrs.get("explicit_padding", None),
            channels=cmd.ps.ifm_shapes[0].depth,
            dtype=cmd.ifm_tensor.dtype,
        )
        top, left, bottom, right = 0, 0, 0, 0

    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def modify_tile_addresses_for_padding(
    tile_box: NpuTileBox, padding_direction: List[int], channels: int, dtype: DataType
) -> NpuTileBox:
    # Addresses are 16-byte aligned when using the NHCWB16 format, which is required to utilize tiling
    # Calculate the offset to the top right, bottom left and bottom right elements in the IFM (top left offset is 0)
199 """
200 Example: 4x4x1 IFM
201 | a b c d | <-- Offset to TR ('d') is (w0-1) = 3
202 | e f g h |
203 | i j k l |
204 | m n o p | <-- Offset to BL ('m') is (w0*(h0-1)) = 12 and to BR ('p') ((w0*h0)-1) = 15
205 """
206 h0, h1, w0, addresses = tile_box
207 elem_size = 2 if dtype == DataType.int16 else 1
208 tr_offset = (w0 - 1) * 16 * elem_size
209 bl_offset = w0 * (h0 - 1) * 16 * (round_up(channels, 16) // 16) * elem_size
210 br_offset = tr_offset + bl_offset
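    # Illustrative example: an int8 IFM with w0 = 4, h0 = 4 and 16 channels gives
    # tr_offset = (4 - 1) * 16 = 48 bytes, bl_offset = 4 * 3 * 16 = 192 bytes and br_offset = 240 bytes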

    # Explicit padding order: (Top, Left, Bottom, Right)
    if padding_direction == (1, 1, 0, 0):
        # Pad top left corner
        """
                     | a a b |
        | a b |  ->  | a a b |
        | c d |      | c c d |
        """
        addresses = [addresses[0]] * 4
        h0, h1, w0 = 1, 1, 1

    elif padding_direction == (1, 0, 0, 1):
        # Pad top right corner
        """
                     | a b b |
        | a b |  ->  | a b b |
        | c d |      | c d d |
        """
        addresses = [addresses[0], addresses[0] + tr_offset, addresses[0], addresses[0] + tr_offset]
        h0, h1, w0 = 1, 1, w0

    elif padding_direction == (0, 1, 1, 0):
        # Pad bottom left corner
        """
        | a b |      | a a b |
        | c d |  ->  | c c d |
                     | c c d |
        """
        addresses = [addresses[0], addresses[0], addresses[0] + bl_offset, addresses[0] + bl_offset]
        h0, h1, w0 = h0, h1, 1

    elif padding_direction == (0, 0, 1, 1):
        # Pad bottom right corner
        """
        | a b |      | a b b |
        | c d |  ->  | c d d |
                     | c d d |
        """
        addresses = [
            addresses[0],
            addresses[0] + tr_offset,
            addresses[0] + bl_offset,
            addresses[0] + br_offset,
        ]
        # h0, h1, w0 = h0, h1, w0
    else:
        assert 0, "Invalid padding direction for tile padding"

    return NpuTileBox(height_0=h0, height_1=h1, width_0=w0, addresses=[int(addr) for addr in addresses])


def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
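    """Returns the base pointer index (region) to use for tensors of the given memory type"""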
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[mem_type].value


def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    mem_limits = dict()
    for mem_type in MemType.all():
        mem_limits[get_region(mem_type, arch)] = arch.mem_type_size(mem_type)
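    # The mem2mem region, used e.g. as the DMA destination for LUTs, is bounded by the available SHRAM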
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    # Force zero point to 0 for ResizeBilinear when converting to a DepthwiseConv since the reference kernel
    # will ignore the zero point.
    if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
        return False
    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (
            ps.primary_op.activation is None
            or forced_ofm_quantization is not None
            or (ps.primary_op.type.is_avgpool_op() and ps.primary_op.activation.op_type.is_relu_op())
        )
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(
    tens: Tensor,
    box: Box,
    arch: ArchitectureFeatures,
    op_shape4D: Shape4D,
    tile_base_offsets: List[int],
    stride_multiplier: Optional[List[int]] = None,
) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    strides = tens.get_strides(op_shape4D)
    assert strides is not None
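    # Note: get_strides returns byte strides ordered N, C, H, W (see the stride multiplier handling below)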

    if stride_multiplier and stride_multiplier != [1, 1, 1]:
        assert (
            tens.format == TensorFormat.NHWC
        ), "Only default stride multiplier ([1, 1, 1]) supported for NHCWB16 format"
        # Multiply strides for C/H/W (in that order) with corresponding stride factor
        for i, stride_factor in enumerate(stride_multiplier, start=1):
            strides[i] *= stride_factor

    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, strides, op_shape4D
    )

    for idx, offset in enumerate(tile_base_offsets):
        addresses[idx] += offset
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    fm.name = tens.name
    return fm


def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales"""
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0

    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = cast(NpuWeightTensor, weight_tensor.src_tensor)

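    # Weights are encoded per core; for a weight-buffered tensor the per-core ranges are packed
    # consecutively with 16-byte alignment, tracked by core_offset below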
    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor == w_tensor_src:
                # Straight from source tensor
                address = weight_tensor.address + weight_range.offset
            else:
                # Weight buffered tensor
                address = weight_tensor.address + core_offset
                core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.explicit_scaling:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
            zero_point = quant.zero_point
            if act.min is not None:
                act.min = scale_f32 * quantise_float32(act.min, scale_f32, zero_point)
            if act.max is not None:
                act.max = scale_f32 * quantise_float32(act.max, scale_f32, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(
        cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0], op.tile_base_offsets_ofm, op.ofm_stride_multiplier
    )
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op, npu_op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[op.ifm_resampling_mode]
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
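    # Fully connected (VectorProduct) always uses depth-first block traversal; otherwise the
    # traversal chosen during weight encoding is used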
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        if cmd.weight_tensor.src_tensor:
            npu_op.block_traversal = cmd.weight_tensor.src_tensor.hw_traversal
        else:
            npu_op.block_traversal = cmd.weight_tensor.hw_traversal
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type.is_resize_op():
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    if op.explicit_scaling:
        # Note: the rescale field is reused for explicit scaling so that it is not exposed in the external API
        npu_op.rescale = op.explicit_scaling
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(
            cmd.ifm2_tensor,
            cmd.ifm2_box,
            arch,
            ps.ifm_shapes[1],
            op.tile_base_offsets_ifm[1],
        )
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.explicit_scaling is not None:
        assert not op.explicit_scaling.per_channel
        assert op.type in (Op.Add, Op.Mul, Op.Sub)
        npu_op.rescale = (op.explicit_scaling.multiplier[0], op.explicit_scaling.shift[0])
    elif op.type == Op.Add and op.original_type.is_resize_op():
        # Force the output scale to be the same as the input scale for
        # resize bilinear/nearest neighbor 1x1 that has been converted to an add
        output_scale = npu_op.ifm2.quantization.scale_f32
    elif op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    elif op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    elif op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
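        # The transfer covers the encoded ranges for all cores: the source address is the start of
        # core 0's range and the size is the sum of the per-core range sizes (16-byte aligned)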
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key in cmd.in_tensor.encoded_ranges:
                weight_range = cmd.in_tensor.encoded_ranges[key]
                sz += round_up(weight_range.total_bytes, 16)

                if core == 0:
                    weight_range = cmd.in_tensor.encoded_ranges[key]
                    src_addr = cmd.in_tensor.address + weight_range.offset
        dest_addr = cmd.out_tensor.address
    else:
        src_addr = cmd.in_tensor.address_for_coordinate(cmd.box.start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(cmd.box.start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
        npu_op.name = cmd.out_tensor.name
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
        npu_op.name = cmd.ps.primary_op.name
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(
            npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
        )