blob: 8c5525b01440a58140d555549510efd3d076a1d0 [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
19from enum import IntEnum
Jonas Ohlsson845e2322022-03-01 12:39:55 +010020from typing import cast
Louis Verhaard024c3552021-03-17 14:26:34 +010021from typing import Dict
Louis Verhaarde8a5a782020-11-02 18:04:27 +010022from typing import List
23from typing import Optional
Jonas Ohlsson845e2322022-03-01 12:39:55 +010024from typing import Tuple
Louis Verhaarde8a5a782020-11-02 18:04:27 +010025
26from .api import NpuActivation
27from .api import NpuActivationOp
28from .api import NpuAddressRange
29from .api import NpuBlockOperation
30from .api import NpuBlockTraversal
31from .api import NpuConv2DOperation
32from .api import NpuConvDepthWiseOperation
33from .api import NpuDataType
34from .api import NpuDmaOperation
35from .api import NpuElementWiseOp
36from .api import NpuElementWiseOperation
37from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010038from .api import NpuLayout
39from .api import NpuOperation
40from .api import NpuPadding
41from .api import NpuPoolingOp
42from .api import NpuPoolingOperation
43from .api import NpuQuantization
44from .api import NpuResamplingMode
45from .api import NpuRoundingMode
46from .api import NpuShape3D
47from .api import NpuTileBox
48from .architecture_features import ArchitectureFeatures
49from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010050from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000051from .errors import UnsupportedFeatureError
Tim Hall3c5cfe92022-03-16 16:31:57 +000052from .ethos_u55_regs.ethos_u55_regs import resampling_mode
Louis Verhaarde8a5a782020-11-02 18:04:27 +010053from .high_level_command_stream import Box
54from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010055from .high_level_command_stream import DMA
56from .high_level_command_stream import NpuStripe
Fredrik Svedberg838df0a2021-09-17 16:29:22 +020057from .numeric_util import quantise_float32
Tim Halld8339a72021-05-27 18:49:40 +010058from .numeric_util import round_up
Louis Verhaarde8a5a782020-11-02 18:04:27 +010059from .operation import NpuBlockType
60from .operation import Op
61from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010062from .register_command_stream_generator import generate_command_stream
63from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010064from .register_command_stream_util import to_npu_kernel
65from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000066from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010067from .tensor import MemType
68from .tensor import Tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069from .tensor import TensorFormat
70from .tensor import TensorPurpose
Jonas Ohlsson845e2322022-03-01 12:39:55 +010071from .weight_compressor import NpuWeightTensor
Tim Halld8339a72021-05-27 18:49:40 +010072from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010073
74
class BasePointerIndex(IntEnum):
    """Base pointer (region) indices used to address tensors in the register command stream."""

    WeightTensor = 0  # region holding the weight tensor (permanent memory types map here)
    ScratchTensor = 1  # region holding the scratch tensor in the tensor arena
    ScratchFastTensor = 2  # region holding the fast scratch tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010079
80
# Translates Vela's internal data types into the data types of the NPU API
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Translates an elementwise op type into the elementwise_mode (NpuElementWiseOp) used by NPU_OP_ELEMENTWISE;
# the Rescale* variants share the mode of their plain counterparts
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


# Inverse of the resampling_mode_map in the register command stream generator:
# turns a register-level resampling_mode back into the API's NpuResamplingMode
resampling_mode_inv_map = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}
113
114
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the IFM/IFM2 operand order can be used as-is.

    The scalar or broadcasted operand must be IFM2. Returns False (operands must be
    swapped) when IFM is a scalar (empty shape), or when IFM has a dimension of 1
    where IFM2 does not. Returns True otherwise, including when IFM2 is a scalar.
    """
    if ifm_shape == []:
        # A scalar has to be placed in IFM2
        return False
    if ifm2_shape == []:
        return True
    # IFM is the broadcasted feature map if any of its dims is 1 while IFM2's is not
    return not any(dim == 1 and dim != dim2 for dim, dim2 in zip(ifm_shape, ifm2_shape))
127
128
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Selects the rounding mode for the NPU operation generated from op.

    NATURAL rounding is chosen for:
    - ResizeBilinear;
    - ConvolutionMxN/ConvolutionDepthWise blocks with an int16 IFM;
    - average pools with a 1-element kernel whose memory function is ConcatSliceWrite,
      when no Quantize op is fused.

    Everything else uses TFL rounding. A non-None op.rounding_mode overrides the choice.
    """
    use_natural = (
        op.type == Op.ResizeBilinear
        or (
            op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
            and op.ifm.dtype == DataType.int16
        )
        or (
            not fused_quantize
            and op.type.is_avgpool_op()
            and op.memory_function == Op.ConcatSliceWrite
            and op.kernel.elements_wh() == 1
        )
    )
    rounding_mode = NpuRoundingMode.NATURAL if use_natural else NpuRoundingMode.TFL
    if op.rounding_mode is not None:
        # An explicitly requested rounding mode wins over the heuristics above
        rounding_mode = op.rounding_mode
    return rounding_mode
149
150
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Computes the IFM padding for the stripe described by cmd.

    Starts from primary_op's explicit padding. When the IFM is streamed in more than one
    horizontal stripe, the stripe's own top/bottom padding is used instead. Left/right
    padding is dropped when the stripe's IFM box does not touch the left/right edge of
    the readable IFM range. VectorProduct blocks never get padding.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Horizontal IFM streaming: top/bottom padding is decided per stripe
    is_single_h_stripe = cmd.is_first_h_stripe and cmd.is_last_h_stripe
    if not is_single_h_stripe:
        top, bottom = cmd.pad_top, cmd.pad_bottom

    # The horizontal range of valid IFM box coordinates depends on whether the primary op
    # was combined with a split slice read (which supplies its own read offset/shape)
    read_offset = primary_op.read_offsets[0]
    read_shape = primary_op.read_shapes[0]
    if read_offset is not None and len(read_offset) >= 2:
        min_start_x = read_offset[-2]
        # NOTE(review): when read_offset[-2] > 0 this upper bound is the read width, not
        # read_offset[-2] + width; confirm which coordinate space ifm_box uses here
        max_end_x = read_shape[-2]
    else:
        min_start_x = 0
        max_end_x = cmd.ps.ifm_shapes[0].width

    # Index from the end: a 1x1 AvgPool may have been added (to fuse an activation)
    # with input/output that are not 4-dimensional
    box_start = cmd.ifm_box.start_coord
    box_end = cmd.ifm_box.end_coord
    if len(box_start) >= 2 and box_start[-2] > min_start_x:
        left = 0
    if len(box_end) >= 2 and box_end[-2] < max_end_x:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100178
179
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer (region) index used to address tensors of mem_type.

    Permanent NPU/CPU memory uses the weight region and Scratch uses the scratch region.
    Scratch_fast gets its own region only when spilling is enabled; otherwise it shares
    the scratch region. Any other memory type raises KeyError.
    """
    if arch.is_spilling_enabled():
        fast_scratch_index = BasePointerIndex.ScratchFastTensor
    else:
        fast_scratch_index = BasePointerIndex.ScratchTensor

    region_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: fast_scratch_index,
    }
    return region_by_mem_type[mem_type].value
193
194
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns a mapping region index -> maximum size of that region in bytes.

    When several memory types share a region, the size of the last such type in
    MemType.all() is the one kept. The MEM2MEM region is limited to the SHRAM size.
    """
    mem_limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100202
203
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode implied by the op type.

    ResizeBilinear uses nearest-neighbour upscaling, Conv2DBackpropInputSwitchedBias uses
    insert-zero upscaling (TRANSPOSE), and every other op type uses NONE.
    NOTE(review): not referenced in this file (set_common_op_fields reads
    op.ifm_resampling_mode instead); possibly kept for external callers.
    """
    if op.type == Op.ResizeBilinear:
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
Louis Verhaarde91b5312022-01-21 13:38:50 +0100213
214
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to program for the operation.

    ConvolutionMxN, VectorProduct and ReduceSum blocks take the depth of the IFM box;
    all other block types take the depth of the OFM box.
    """
    depth_from_ifm = npu_block_type in (
        NpuBlockType.ConvolutionMxN,
        NpuBlockType.VectorProduct,
        NpuBlockType.ReduceSum,
    )
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100221
222
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Decides whether the quantization of tens should use 0 as its zero point.

    Always true for an int32 IFM tensor. Otherwise it can only apply when the primary op
    is AvgPool (without explicit scaling), ResizeBilinear, CLZ or SHL. Even then it
    requires all of the following:
    - the pass has no fused Quantize op;
    - the op's memory_function is not ConcatSliceWrite;
    - the op has no activation, or its OFM quantization is forced, or it is an
      average pool with a ReLU activation and no rescale.
    """
    primary_op = ps.primary_op
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if primary_op.type == Op.AvgPool and primary_op.explicit_scaling:
        return False

    fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    forced_ofm_quantization = primary_op.forced_output_quantization
    if primary_op.activation is None or forced_ofm_quantization is not None:
        activation_allows_0 = True
    else:
        # With an activation, only a ReLU on an average pool without rescale qualifies
        activation_allows_0 = (
            primary_op.type.is_avgpool_op()
            and primary_op.activation.op_type.is_relu_op()
            and not primary_op.rescale
        )
    return activation_allows_0 and primary_op.memory_function != Op.ConcatSliceWrite and not fused_quantize
247
248
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Returns the NPU quantization for an IFM/IFM2 tensor of the pass, or None.

    The primary op's forced input quantization takes precedence over the tensor's own
    quantization. The zero point is replaced by 0 when use_zero_point_0 requests it.
    """
    op = ps.primary_op
    if op.forced_input_quantization is not None:
        quant = op.forced_input_quantization
    else:
        quant = tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100260
261
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Returns the NPU quantization for the OFM tensor of the pass, or None.

    The primary op's forced output quantization (used in LUTs) takes precedence over the
    tensor's own quantization. The zero point is replaced by 0 when use_zero_point_0
    requests it.
    """
    op = ps.primary_op
    if op.forced_output_quantization is not None:
        quant = op.forced_output_quantization
    else:
        quant = tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
275
276
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates an NpuFeatureMap for the part of tens covered by box.

    Fills in region, data type, layout, tile addresses and strides. Shape and
    quantization are not set here.
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Tiles without an address (None) are encoded as address 0
    tile_addresses = [0 if address is None else int(address) for address in addresses]
    fm.tiles = NpuTileBox(height_0=height_0, height_1=height_1, width_0=width_0, addresses=tile_addresses)

    # Entries 2, 3 and 1 of get_strides() are the height, width and depth strides
    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm
300
301
def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales.

    One weight range and one scale range are produced for every NPU core that has an
    encoded stream for the depth slice starting at weight_box.start_coord[-1]. Scales are
    taken from scale_tensor when a standalone scale tensor is given; otherwise they are
    taken from the weight stream itself. All range lengths are rounded up to 16 bytes.
    """
    weight_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0
    # A buffered weight tensor keeps its encoding information on its source tensor
    src = cast(NpuWeightTensor, weight_tensor.src_tensor) if weight_tensor.src_tensor else weight_tensor
    depth = weight_box.start_coord[-1]

    weight_ranges: List[NpuAddressRange] = []
    scale_ranges: List[NpuAddressRange] = []
    buffer_offset = 0
    for core in range(arch.ncores):
        key = WeightKey(core, depth)
        if key not in src.encoded_ranges:
            continue
        weight_range = src.encoded_ranges[key]
        if weight_tensor == src:
            # Weights are read straight from the source tensor
            stream_address = weight_tensor.address + weight_range.offset
        else:
            # Weight buffer: per-core streams are packed back to back, each rounded up to 16 bytes
            stream_address = weight_tensor.address + buffer_offset
            buffer_offset += round_up(weight_range.total_bytes, 16)

        weight_ranges.append(
            NpuAddressRange(
                weight_region,
                int(stream_address + weight_range.weight_offset),
                round_up(int(weight_range.weight_bytes), 16),
            )
        )

        if scale_tensor:
            assert scale_tensor.src_tensor is None  # Must be standalone
            scale_range = scale_tensor.encoded_ranges[key]
            scale_address = scale_tensor.address + scale_range.offset
            scale_ranges.append(
                NpuAddressRange(scale_region, int(scale_address), round_up(int(scale_range.scale_bytes), 16))
            )
        else:
            # Combined tensor: the scales start at the beginning of this core's stream
            scale_ranges.append(
                NpuAddressRange(weight_region, int(stream_address), round_up(int(weight_range.scale_bytes), 16))
            )

    return weight_ranges, scale_ranges
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100347
348
def create_npu_activation(op: Operation) -> NpuActivation:
    """Builds the NpuActivation for op's fused activation function.

    An op without an activation gets a plain NONE_OR_RELU activation. Tanh, Sigmoid and
    LUT activations map to their NPU counterparts, and ReLU variants map to NONE_OR_RELU
    with the activation's min/max. Any other activation raises UnsupportedFeatureError.
    """
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = op.activation.op_type
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif faf.is_relu_op():
        act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
        ofm_quant = op.ofm.quantization
        if ofm_quant and ofm_quant.zero_point:
            # Non-zero OFM zero point: re-express the clamp limits through quantise_float32
            # with the OFM scale/zero point (presumably to fold the zero point into them - TODO confirm)
            scale = 1 if ofm_quant.scale_f32 is None else ofm_quant.scale_f32
            zero_point = ofm_quant.zero_point
            if act.min is not None:
                act.min = scale * quantise_float32(act.min, scale, zero_point)
            if act.max is not None:
                act.max = scale * quantise_float32(act.max, scale, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act
378
379
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Populates the fields shared by all NPU block operations from the stripe command.

    Sets the following, then returns npu_op:
    - IFM/OFM feature maps with shape and quantization;
    - weights and biases, when the command has a weight tensor;
    - activation, fused-quantize flag, rounding mode and block config;
    - padding and kernel, for non-elementwise ops only;
    - IFM upscaling.
    """
    ps = cmd.ps
    primary_op = ps.primary_op

    # IFM: height from the stripe's IFM box, width from the pass' full IFM shape
    ifm_shape = NpuShape3D(
        height=cmd.ifm_box.get_block().height,
        width=cmd.ps.ifm_shapes[0].width,
        depth=get_ifm_depth(primary_op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box),
    )
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = ifm_shape
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: shape taken entirely from the stripe's OFM box
    ofm_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(primary_op)
    # A Quantize op packed into the same pass is fused into this NPU operation
    npu_op.fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)
    # block_config entries 0, 1 and 3 hold the block height, width and depth
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[primary_op.ifm_resampling_mode]
    return npu_op
410
411
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation.

    VectorProduct blocks always use depth-first block traversal. Other convolutions use
    the hw_traversal recorded on the weight tensor, or on its source tensor when it has one.
    """
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
        return npu_op
    # Buffered weights keep their encoding information on the source tensor
    encoded_weights = cmd.weight_tensor.src_tensor or cmd.weight_tensor
    npu_op.block_traversal = encoded_weights.hw_traversal
    return npu_op
424
425
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation; only the common block fields are set."""
    depthwise_op = NpuConvDepthWiseOperation()
    set_common_op_fields(depthwise_op, cmd, arch)
    return depthwise_op
431
432
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation.

    Max pools map to MAX, average pools and ResizeBilinear to AVERAGE, and ReduceSum to
    REDUCE_SUM. The op's rescale is copied over; explicit scaling, when present, is
    passed on through the same rescale field.
    """
    op = cmd.ps.primary_op
    # AVERAGE is also the fallback if the assert below is compiled out (python -O)
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type == Op.ReduceSum and not (op.type.is_avgpool_op() or op.type == Op.ResizeBilinear):
        pool_op = NpuPoolingOp.REDUCE_SUM
    elif not (op.type.is_avgpool_op() or op.type == Op.ResizeBilinear):
        assert 0, f"Unknown pool type {op.type}"

    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Explicit scaling reuses the rescale field so it need not be exposed in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op
455
456
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation.

    For binary ops the operands are swapped when needed, so that a scalar or broadcasted
    operand ends up in IFM2. This swap updates cmd and ps.ifm_shapes in place. Some op
    types then override the OFM scale; later checks take precedence over earlier ones.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted operand has to be IFM2, so swap the operands
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            # NOTE(review): ps.ifm_shapes belongs to the pass, not to this stripe. If several
            # stripes of one pass are converted, the swapped order carries over to the next
            # stripe while its cmd operands are unswapped - confirm this cannot happen
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True

        ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2 = ifm2
        ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # Scalar operand: given as an immediate value, with an empty shape
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_block = cmd.ifm2_box.get_block()
            ifm2.shape = NpuShape3D(
                height=ifm2_block.height, width=ps.ifm_shapes[1].width, depth=ifm2_block.depth
            )

    set_common_op_fields(npu_op, cmd, arch)

    # Work out whether the OFM scale must be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # A 1x1 ResizeBilinear converted to Add keeps its input scale as output scale
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if (
        op.type in (Op.Add, Op.Mul, Op.Sub)
        and op.activation is not None
        and op.activation.op_type in (Op.Sigmoid, Op.Tanh)
    ):
        output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
504
505
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation.

    LUT destinations are addressed through the MEM2MEM base pointer; every other tensor
    uses the region of its memory type.

    For encoded weights:
    - the transfer starts at core 0's stream for the depth slice at cmd.box.start_coord[-1];
    - its length is the sum of every core's stream for that slice, each rounded up to 16 bytes.

    Other tensors are copied from the box's start coordinate up to its end coordinate.
    """
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        encoded_ranges = cmd.in_tensor.encoded_ranges
        depth = cmd.box.start_coord[-1]
        first_key = WeightKey(0, depth)
        # The transfer starts at core 0's stream. Fail clearly when it is missing, instead
        # of later with an unbound src_addr/dest_addr.
        assert first_key in encoded_ranges, f"No encoded weights for core 0 at depth {depth}"
        src_addr = cmd.in_tensor.address + encoded_ranges[first_key].offset
        dest_addr = cmd.out_tensor.address
        # The streams of all cores for this depth slice are moved in a single transfer
        sz = 0
        for core in range(arch.ncores):
            key = WeightKey(core, depth)
            if key in encoded_ranges:
                sz += round_up(encoded_ranges[key].total_bytes, 16)
    else:
        start_coord = cmd.box.start_coord
        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
535
536
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation.

    DMA commands become NpuDmaOperation. NpuStripe commands are dispatched on the NPU
    block type of the pass' primary op. Any other command type, or an unhandled block
    type, is an internal error.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    else:
        # Without this branch an unexpected command would surface as an UnboundLocalError on npu_op
        assert 0, f"Unknown high level command {type(cmd).__name__}"
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100555
556
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates the register command stream for subgraph sg and stores it in sg.register_command_stream.

    Every high level command is converted to an NpuOperation; stripes whose pass has
    NpuBlockType.Default are skipped with a warning. When the subgraph has at least one
    high level command:
    - a debug database stream is registered for it and stored in sg.generated_stream_id;
    - the register command stream is generated.

    Otherwise sg.register_command_stream is left untouched.
    """
    npu_op_list = []
    npu_op_to_cmd = {}  # NpuOperation -> high level command it was created from
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
            continue
        npu_op = convert_command_to_npu_op(cmd, arch)
        npu_op_list.append(npu_op)
        npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)

    if not sg.high_level_command_stream:
        return

    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Records the primary op of the originating pass for the command at offset (DMA ops are skipped)"""
        if isinstance(npu_op, NpuDmaOperation):
            return
        DebugDatabase.add_command(stream_id, offset, npu_op_to_cmd[npu_op].ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
    )