# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import NpuWeightTensor
from .weight_compressor import WeightKey


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


# inverse of the resampling_mode_map in the register command stream generator
resampling_mode_inv_map = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
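    # A scalar or broadcast feature map has to be placed in IFM2, so this returns False when
    # the operands need to be swapped (the caller then sets reversed_operands). For example,
    # ifm_ifm2_correct_order([1, 8, 8, 1], [1, 8, 8, 16]) returns False because IFM1 is the
    # broadcast operand.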
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
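    # TFLite (TFL) rounding is the default. NATURAL rounding is selected for ResizeBilinear,
    # for int16 convolutions and for some 1x1 average-pool based concat slice writes, and an
    # explicitly set op.rounding_mode always takes precedence.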
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
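    # The explicit padding from the graph attributes is adjusted per stripe: with horizontal
    # IFM streaming the top/bottom padding comes from the stripe itself, and left/right
    # padding is cleared when the IFM box does not touch the (possibly slice-read adjusted)
    # horizontal edges of the feature map.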
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # The IFM box coordinate range depends on whether the primary op was combined with a split slice read
    ifm_read_offset = primary_op.read_offsets[0]
    ifm_read_shape = primary_op.read_shapes[0]
    if ifm_read_offset is None or len(ifm_read_offset) < 2:
        box_start_coord_min = 0
        box_end_coord_max = cmd.ps.ifm_shapes[0].width
    else:
        box_start_coord_min = ifm_read_offset[-2]
        box_end_coord_max = ifm_read_shape[-2]

    # Index from the end, since a 1x1 AvgPool might have been added with non-4-dimensional input/output
    # because an activation function needed to be fused.
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > box_start_coord_min:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < box_end_coord_max:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[mem_type].value


def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    mem_limits = dict()
    for mem_type in MemType.all():
        mem_limits[get_region(mem_type, arch)] = arch.mem_type_size(mem_type)
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits


def get_double_buffer_offset(arch: ArchitectureFeatures, range_index: int, core: int) -> int:
    """Returns 0 if the first half of a double buffer should be used, 1 if the second half should be used"""
    return ((range_index - core) // arch.ncores) % 2


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (
            ps.primary_op.activation is None
            or forced_ofm_quantization is not None
            or (
                ps.primary_op.type.is_avgpool_op()
                and ps.primary_op.activation.op_type.is_relu_op()
                and not ps.primary_op.rescale
            )
        )
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    fm.name = tens.name
    return fm


def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales"""
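    # Weights are encoded per core: for each core the matching encoded range is looked up and
    # turned into an NpuAddressRange. Double-buffered weights are offset into one half of the
    # buffer (see get_double_buffer_offset), and the scales/biases either come from a
    # standalone scale tensor or share the weight tensor's region.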
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0

    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = cast(NpuWeightTensor, weight_tensor.src_tensor)

    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                assert weight_tensor != w_tensor_src
                # Double buffered inside weight_tensor
                address = weight_tensor.address + core_offset
                address += get_double_buffer_offset(arch, weight_range.index, core) * w_tensor_src.max_range_bytes
                core_offset += round_up(weight_range.total_bytes, 16)
            else:
                if weight_tensor == w_tensor_src:
                    # Straight from source tensor
                    address = weight_tensor.address + weight_range.offset
                else:
                    # Single buffered inside weight tensor
                    address = weight_tensor.address + core_offset
                    core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
            zero_point = quant.zero_point
            if act.min is not None:
                act.min = scale_f32 * quantise_float32(act.min, scale_f32, zero_point)
            if act.max is not None:
                act.max = scale_f32 * quantise_float32(act.max, scale_f32, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[op.ifm_resampling_mode]
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        if cmd.weight_tensor.src_tensor:
            npu_op.block_traversal = cmd.weight_tensor.src_tensor.hw_traversal
        else:
            npu_op.block_traversal = cmd.weight_tensor.hw_traversal
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: rescale is reused for explicit scaling so that it is not exposed in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcast feature map has to be the IFM2 tensor, so switch the IFMs
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force the output scale to be the same as the input scale for a
        # resizebilinear 1x1 that has been converted to an add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key in cmd.in_tensor.encoded_ranges:
                weight_range = cmd.in_tensor.encoded_ranges[key]
                sz += round_up(weight_range.total_bytes, 16)

                if core == 0:
                    weight_range = cmd.in_tensor.encoded_ranges[key]
                    src_addr = cmd.in_tensor.address + weight_range.offset

                    if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                        dest_addr = cmd.out_tensor.address + cmd.in_tensor.max_range_bytes * (
                            get_double_buffer_offset(arch, weight_range.index, core)
                        )
                    else:
                        dest_addr = cmd.out_tensor.address
    else:
        start_coord = cmd.box.start_coord
        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
        npu_op.name = cmd.out_tensor.name
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
        npu_op.name = cmd.ps.primary_op.name
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(
            npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
        )