blob: 2ce150fca71ca011b95061dbabf969485d910a66 [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
19from enum import IntEnum
Jonas Ohlsson845e2322022-03-01 12:39:55 +010020from typing import cast
Louis Verhaard024c3552021-03-17 14:26:34 +010021from typing import Dict
Louis Verhaarde8a5a782020-11-02 18:04:27 +010022from typing import List
23from typing import Optional
Jonas Ohlsson845e2322022-03-01 12:39:55 +010024from typing import Tuple
Louis Verhaarde8a5a782020-11-02 18:04:27 +010025
26from .api import NpuActivation
27from .api import NpuActivationOp
28from .api import NpuAddressRange
29from .api import NpuBlockOperation
30from .api import NpuBlockTraversal
31from .api import NpuConv2DOperation
32from .api import NpuConvDepthWiseOperation
33from .api import NpuDataType
34from .api import NpuDmaOperation
35from .api import NpuElementWiseOp
36from .api import NpuElementWiseOperation
37from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010038from .api import NpuLayout
39from .api import NpuOperation
40from .api import NpuPadding
41from .api import NpuPoolingOp
42from .api import NpuPoolingOperation
43from .api import NpuQuantization
44from .api import NpuResamplingMode
45from .api import NpuRoundingMode
46from .api import NpuShape3D
47from .api import NpuTileBox
48from .architecture_features import ArchitectureFeatures
49from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010050from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000051from .errors import UnsupportedFeatureError
Tim Hall3c5cfe92022-03-16 16:31:57 +000052from .ethos_u55_regs.ethos_u55_regs import resampling_mode
Louis Verhaarde8a5a782020-11-02 18:04:27 +010053from .high_level_command_stream import Box
54from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010055from .high_level_command_stream import DMA
56from .high_level_command_stream import NpuStripe
Fredrik Svedberg838df0a2021-09-17 16:29:22 +020057from .numeric_util import quantise_float32
Tim Halld8339a72021-05-27 18:49:40 +010058from .numeric_util import round_up
Louis Verhaarde8a5a782020-11-02 18:04:27 +010059from .operation import NpuBlockType
60from .operation import Op
61from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010062from .register_command_stream_generator import generate_command_stream
63from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010064from .register_command_stream_util import to_npu_kernel
65from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000066from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010067from .tensor import MemType
68from .tensor import Tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069from .tensor import TensorFormat
70from .tensor import TensorPurpose
Jonas Ohlsson845e2322022-03-01 12:39:55 +010071from .weight_compressor import NpuWeightTensor
Tim Halld8339a72021-05-27 18:49:40 +010072from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010073
74
class BasePointerIndex(IntEnum):
    """Base pointer indices that the tensor memory types are mapped onto"""

    # base address index for the Weight tensor
    WeightTensor = 0
    # base address index for the Scratch_tensor in the TensorArena
    ScratchTensor = 1
    # base address for the Scratch_fast_tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010079
80
# Maps the compiler's DataType to the NpuDataType used for feature maps.
# Data types without an entry here cannot be converted (lookup raises KeyError).
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
88
89
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# The rescaling variants (RescaleMul, RescaleAdd) map to the same mode as Mul and Add.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
105
106
# inverse of the resampling_mode_map in the register command stream generator:
# translates a register-level resampling_mode into the corresponding NpuResamplingMode
resampling_mode_inv_map = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}
113
114
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Checks whether the two elementwise inputs are already in the required order

    A scalar (empty shape) or a broadcasted feature map has to be IFM2.
    Returns False when IFM is the scalar/broadcasted operand, i.e. when the
    caller has to swap the operands, and True otherwise. Dimensions are
    compared pairwise from the front; extra trailing dimensions of the longer
    shape are not inspected.
    """
    if ifm_shape == []:
        # A scalar in the first position always has to move to IFM2
        return False
    if ifm2_shape == []:
        return True

    # IFM is the broadcasted operand if it has size 1 in a dimension where IFM2 differs
    ifm_is_broadcast = any(dim == 1 and dim != dim2 for dim, dim2 in zip(ifm_shape, ifm2_shape))
    return not ifm_is_broadcast
127
128
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used

    The default is TFL rounding. NATURAL rounding is selected for resize ops,
    for (depthwise) convolutions with an int16 IFM, and for 1x1 average pools
    that write into a concat slice without a fused quantize. A rounding mode
    set explicitly on the operation overrides all of the above.
    """
    op_type = op.type
    wants_natural = (
        op_type.is_resize_op()
        or (
            op_type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
            and op.ifm.dtype == DataType.int16
        )
        or (
            not fused_quantize
            and op_type.is_avgpool_op()
            and op.memory_function == Op.ConcatSliceWrite
            and op.kernel.elements_wh() == 1
        )
    )
    derived_mode = NpuRoundingMode.NATURAL if wants_natural else NpuRoundingMode.TFL

    # An explicit rounding mode on the operation always wins
    explicit_mode = op.rounding_mode
    return derived_mode if explicit_mode is None else explicit_mode
149
150
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Creates the padding for the stripe described by cmd

    VectorProduct operations get no padding. Otherwise the padding starts out
    as the operation's "explicit_padding" attribute; top/bottom are taken from
    the command when the stripe does not cover the whole height, and left/right
    are cleared when the IFM box does not touch the corresponding edge.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    pad_top, pad_left, pad_bottom, pad_right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    covers_all_rows = cmd.is_first_h_stripe and cmd.is_last_h_stripe
    if not covers_all_rows:
        pad_top, pad_bottom = cmd.pad_top, cmd.pad_bottom

    # the ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    read_offset = primary_op.read_offsets[0]
    read_shape = primary_op.read_shapes[0]
    if read_offset is not None and len(read_offset) >= 2:
        min_start = read_offset[-2]
        max_end = read_shape[-2]
    else:
        min_start = 0
        max_end = cmd.ps.ifm_shapes[0].width

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because of activation function needed to be fused.
    box_start = cmd.ifm_box.start_coord
    box_end = cmd.ifm_box.end_coord
    if len(box_start) >= 2 and box_start[-2] > min_start:
        pad_left = 0
    if len(box_end) >= 2 and box_end[-2] < max_end:
        pad_right = 0
    return NpuPadding(top=pad_top, left=pad_left, bottom=pad_bottom, right=pad_right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100178
179
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer index (region) for the given memory type

    Permanent_NPU and Permanent_CPU share the weight region. Scratch_fast only
    gets a region of its own when spilling is enabled on the architecture;
    otherwise it shares the scratch region. Any other memory type raises KeyError.
    """
    if arch.is_spilling_enabled():
        scratch_fast_index = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_index = BasePointerIndex.ScratchTensor

    index_for_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return index_for_mem_type[mem_type].value
193
194
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes

    When several memory types map to the same region, the size of the last one
    yielded by MemType.all() is the one that is kept. The mem2mem region is
    limited by the SHRAM size.
    """
    limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100202
203
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to use for the given block type

    ConvolutionMxN, VectorProduct and ReduceSum take the depth from the IFM
    box; every other block type takes it from the OFM box.
    """
    depth_from_ifm = npu_block_type in (
        NpuBlockType.ConvolutionMxN,
        NpuBlockType.VectorProduct,
        NpuBlockType.ReduceSum,
    )
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100210
211
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point

    int32 IFM tensors always use zero point 0. Apart from that, only AvgPool,
    CLZ, SHL and resize ops are candidates, and AvgPool with explicit scaling
    is excluded. For the remaining candidates the result depends on the fused
    activation, forced output quantization, the memory function and on whether
    a Quantize op is fused into the pass.
    """
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True

    primary_op = ps.primary_op
    op_type = primary_op.type
    is_candidate = op_type in (Op.AvgPool, Op.CLZ, Op.SHL) or op_type.is_resize_op()
    if not is_candidate:
        return False
    if op_type == Op.AvgPool and primary_op.explicit_scaling:
        return False

    has_fused_quantize = any(fused_op.type == Op.Quantize for fused_op in ps.ops)
    forced_ofm_quant = primary_op.forced_output_quantization
    activation = primary_op.activation
    activation_permits_zero = (
        activation is None
        or forced_ofm_quant is not None
        or (op_type.is_avgpool_op() and activation.op_type.is_relu_op() and not primary_op.rescale)
    )
    return (
        activation_permits_zero
        and (primary_op.memory_function != Op.ConcatSliceWrite)
        and not has_fused_quantize
    )
236
237
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2

    The primary op's forced input quantization takes precedence over the
    tensor's own quantization. Returns None when neither is available.
    """
    quant = ps.primary_op.forced_input_quantization
    if quant is None:
        quant = tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100249
250
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM

    The primary op's forced output quantization (used in LUTs) takes precedence
    over the output tensor's own quantization. Returns None when neither is
    available.
    """
    quant = ps.primary_op.forced_output_quantization
    if quant is None:
        quant = tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
264
265
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated

    Sets region, data type, layout, tiles, strides and name for the part of
    tens delimited by box. Shape and quantization are not set here.

    Raises AssertionError if the tensor format is neither NHWC nor NHCWB16 and
    KeyError if the tensor's data type has no entry in dtype_map.
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        # Raised explicitly rather than via `assert 0` so that the check is not
        # stripped when Python runs with -O
        raise AssertionError("Incorrect tensor format")
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Entries that are None are replaced by address 0 (in place, as before)
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides(shape4D=op_shape4D)
    # NOTE(review): the index -> height/width/depth mapping follows the order returned by
    # Tensor.get_strides(), which is not visible from here
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    fm.name = tens.name
    return fm
290
291
def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales

    One weight range and one scale range is produced for every core that has an
    encoded range for the depth at which weight_box starts. scale_tensor may be
    falsy, in which case the scales are taken from the weight tensor itself.
    """
    weight_ranges = []
    scale_ranges = []
    weight_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0

    # The encoded ranges are kept on the source tensor when there is one
    if weight_tensor.src_tensor:
        range_owner = cast(NpuWeightTensor, weight_tensor.src_tensor)
    else:
        range_owner = weight_tensor

    buffered_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key not in range_owner.encoded_ranges:
            continue
        encoded = range_owner.encoded_ranges[key]
        if weight_tensor == range_owner:
            # Straight from source tensor
            base_address = weight_tensor.address + encoded.offset
        else:
            # Weight buffered tensor
            base_address = weight_tensor.address + buffered_offset
            buffered_offset += round_up(encoded.total_bytes, 16)

        # Location of weights in tensor
        weight_ranges.append(
            NpuAddressRange(
                weight_region, int(base_address + encoded.weight_offset), round_up(int(encoded.weight_bytes), 16)
            )
        )

        # Location of standalone scales or combined weights tensor scales
        if scale_tensor:
            assert scale_tensor.src_tensor is None  # Must be standalone
            scale_encoded = scale_tensor.encoded_ranges[key]
            scale_address = scale_tensor.address + scale_encoded.offset
            scale_range = NpuAddressRange(
                scale_region, int(scale_address), round_up(int(scale_encoded.scale_bytes), 16)
            )
        else:
            scale_range = NpuAddressRange(
                weight_region, int(base_address), round_up(int(encoded.scale_bytes), 16)
            )
        scale_ranges.append(scale_range)

    return weight_ranges, scale_ranges
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100337
338
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function

    Tanh, Sigmoid and LUT map to their dedicated activation ops; relu variants
    (and a missing activation) map to NONE_OR_RELU. Any other fused activation
    raises UnsupportedFeatureError.
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    dedicated_ops = (
        (Op.Tanh, NpuActivationOp.TANH),
        (Op.Sigmoid, NpuActivationOp.SIGMOID),
        (Op.LUT, NpuActivationOp.TABLE_LOOKUP),
    )
    act_op = next((npu_act for hl_op, npu_act in dedicated_ops if faf == hl_op), None)
    if act_op is None:
        if not faf.is_relu_op():
            raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")
        act_op = NpuActivationOp.NONE_OR_RELU

    act = NpuActivation(act_op)
    act.min = activation.min
    act.max = activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
        ofm_quant = op.ofm.quantization
        if ofm_quant and ofm_quant.zero_point:  # Zero point is not 0
            # Adjust the clamp limits using the OFM scale and zero point
            scale = 1 if ofm_quant.scale_f32 is None else ofm_quant.scale_f32
            zero_point = ofm_quant.zero_point
            if act.min is not None:
                act.min = scale * quantise_float32(act.min, scale, zero_point)
            if act.max is not None:
                act.max = scale * quantise_float32(act.max, scale, zero_point)
    act.lookup_table_index = activation.lut_index
    return act
368
369
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation

    Populates IFM, OFM, weights/biases (only when the stripe has a weight
    tensor), activation, fused quantize flag, rounding mode, block config,
    padding and kernel (non-elementwise ops only) and the IFM upscale mode.
    Returns the npu_op that was passed in.
    """
    ps = cmd.ps
    primary_op = ps.primary_op

    # IFM: height from the stripe's IFM box, width from ps.ifm_shapes[0]
    ifm_shape = NpuShape3D(
        height=cmd.ifm_box.get_block().height,
        width=ps.ifm_shapes[0].width,
        depth=get_ifm_depth(primary_op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box),
    )
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = ifm_shape
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: all three dimensions come from the stripe's OFM box
    ofm_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        weight_ranges, bias_ranges = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
        npu_op.weights, npu_op.biases = weight_ranges, bias_ranges
    npu_op.activation = create_npu_activation(primary_op)
    npu_op.fused_quantize = any(fused_op.type == Op.Quantize for fused_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[primary_op.ifm_resampling_mode]
    return npu_op
400
401
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation

    VectorProduct always uses depth first block traversal; otherwise the
    traversal is taken from the weight tensor's source tensor when it has one,
    else from the weight tensor itself.
    """
    conv_op = NpuConv2DOperation()
    set_common_op_fields(conv_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        conv_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
        return conv_op

    if cmd.weight_tensor.src_tensor:
        traversal = cmd.weight_tensor.src_tensor.hw_traversal
    else:
        traversal = cmd.weight_tensor.hw_traversal
    conv_op.block_traversal = traversal
    return conv_op
414
415
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation

    Only the fields common to all block operations are needed.
    """
    depthwise_op = NpuConvDepthWiseOperation()
    set_common_op_fields(depthwise_op, cmd, arch)
    return depthwise_op
421
422
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation

    Max pool ops map to MAX, average pool and resize ops to AVERAGE and
    ReduceSum to REDUCE_SUM.

    Raises AssertionError for any other op type, and when both rescale and
    explicit scaling are set on the operation.
    """
    ps = cmd.ps
    op = ps.primary_op
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type.is_resize_op():
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        # Raised explicitly rather than via `assert 0`: with -O the assert would be
        # stripped and pool_op would be unbound on the next line
        raise AssertionError(f"Unknown pool type {op.type}")
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op
444
445
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation

    For binary operations the IFM/IFM2 tensors, boxes and shapes are swapped in
    place on cmd and cmd.ps when the scalar/broadcasted operand is not already
    IFM2, and reversed_operands is set on the result. For some op types the OFM
    scale is overridden after the common fields have been set.
    """
    ps = cmd.ps
    primary_op = ps.primary_op
    assert primary_op.type in elementwise_op_map, f"Unknown elementwise type {primary_op.type}"
    mode = elementwise_op_map[primary_op.type]
    npu_op = NpuElementWiseOperation(mode)

    if mode not in UNARY_ELEMWISE_OPS:
        first_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        second_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(first_shape, second_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_block = cmd.ifm2_box.get_block()
            npu_op.ifm2.shape = NpuShape3D(
                height=ifm2_block.height, width=ps.ifm_shapes[1].width, depth=ifm2_block.depth
            )
    set_common_op_fields(npu_op, cmd, arch)

    # Check if output scale needs to be overridden
    op_type = primary_op.type
    output_scale = None
    if op_type == Op.Add and primary_op.original_type.is_resize_op():
        # Force output scale same as the input scale for
        # resizebilinear/nearestneighbor 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    elif op_type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    elif op_type == Op.LeakyRelu:
        output_scale = primary_op.attrs["alpha"]
    elif op_type in (Op.RescaleAdd, Op.RescaleMul):
        assert primary_op.rescale is not None, f"{op_type} must have rescale"
        npu_op.rescale = primary_op.rescale
    elif op_type in (Op.Add, Op.Mul, Op.Sub):
        fused_activation = primary_op.activation
        if fused_activation is not None and fused_activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000

    if output_scale is None:
        return npu_op
    # Keep the OFM zero point, replace only the scale
    npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
493
494
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation

    For weight tensors the transfer size is the sum of the 16-byte rounded
    encoded ranges of all cores, starting at core 0's range. For other tensors
    the source/destination addresses and the size are derived from the
    command's box. LUT destinations use the mem2mem region.
    """
    in_tensor = cmd.in_tensor
    out_tensor = cmd.out_tensor
    src_region = get_region(in_tensor.mem_type, arch)
    if out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(out_tensor.mem_type, arch)

    if in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key not in in_tensor.encoded_ranges:
                continue
            encoded = in_tensor.encoded_ranges[key]
            sz += round_up(encoded.total_bytes, 16)

            if core == 0:
                # The transfer starts where core 0's encoded range starts
                src_addr = in_tensor.address + encoded.offset
                dest_addr = out_tensor.address
    else:
        box_start = cmd.box.start_coord
        src_addr = in_tensor.address_for_coordinate(box_start)
        dest_addr = out_tensor.address_for_coordinate(box_start)
        sz = in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src_range = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest_range = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src_range, dest_range)
524
525
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    DMA commands become NpuDmaOperation named after the output tensor.
    NpuStripe commands are dispatched on the primary op's NPU block type and
    named after the primary op.

    Raises AssertionError if the command is neither a DMA nor an NpuStripe, or
    if the stripe's block type is not handled.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
        npu_op.name = cmd.out_tensor.name
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            # Raised explicitly rather than via `assert 0` so it is not stripped under -O,
            # which would leave npu_op unbound below
            raise AssertionError(f"Unknown command type {npu_block_type}")
        npu_op.name = cmd.ps.primary_op.name
    else:
        # Previously fell through to `return npu_op` and failed with UnboundLocalError
        raise AssertionError(f"Unknown command type {type(cmd).__name__}")
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100546
547
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream

    Stripes whose pass has the Default block type are skipped with a printed
    warning. Nothing is generated (and sg is left untouched) when the high
    level command stream is empty. nng is not used by this function.
    """
    # Convert high level command stream to list of NpuOperation
    npu_ops = []
    cmd_for_npu_op = {}  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
            continue
        npu_op = convert_command_to_npu_op(cmd, arch)
        npu_ops.append(npu_op)
        cmd_for_npu_op[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)

    # Generate register commands
    if len(sg.high_level_command_stream) == 0:
        return
    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        if isinstance(npu_op, NpuDmaOperation):
            # DMA operations are not recorded
            return
        source_cmd = cmd_for_npu_op[npu_op]
        DebugDatabase.add_command(stream_id, offset, source_cmd.ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_ops, arch, verbose, mem_limits, add_to_debug_db, cmd_for_npu_op
    )