blob: e6bfc1c4c4d596cad9a3b6381501c8b1afa6ecf9 [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
19from enum import IntEnum
Jonas Ohlsson845e2322022-03-01 12:39:55 +010020from typing import cast
Louis Verhaard024c3552021-03-17 14:26:34 +010021from typing import Dict
Louis Verhaarde8a5a782020-11-02 18:04:27 +010022from typing import List
23from typing import Optional
Jonas Ohlsson845e2322022-03-01 12:39:55 +010024from typing import Tuple
Louis Verhaarde8a5a782020-11-02 18:04:27 +010025
26from .api import NpuActivation
27from .api import NpuActivationOp
28from .api import NpuAddressRange
29from .api import NpuBlockOperation
30from .api import NpuBlockTraversal
31from .api import NpuConv2DOperation
32from .api import NpuConvDepthWiseOperation
33from .api import NpuDataType
34from .api import NpuDmaOperation
35from .api import NpuElementWiseOp
36from .api import NpuElementWiseOperation
37from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010038from .api import NpuLayout
39from .api import NpuOperation
40from .api import NpuPadding
41from .api import NpuPoolingOp
42from .api import NpuPoolingOperation
43from .api import NpuQuantization
44from .api import NpuResamplingMode
45from .api import NpuRoundingMode
46from .api import NpuShape3D
47from .api import NpuTileBox
48from .architecture_features import ArchitectureFeatures
49from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010050from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000051from .errors import UnsupportedFeatureError
Tim Hall3c5cfe92022-03-16 16:31:57 +000052from .ethos_u55_regs.ethos_u55_regs import resampling_mode
Louis Verhaarde8a5a782020-11-02 18:04:27 +010053from .high_level_command_stream import Box
54from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010055from .high_level_command_stream import DMA
56from .high_level_command_stream import NpuStripe
Fredrik Svedberg838df0a2021-09-17 16:29:22 +020057from .numeric_util import quantise_float32
Tim Halld8339a72021-05-27 18:49:40 +010058from .numeric_util import round_up
Louis Verhaarde8a5a782020-11-02 18:04:27 +010059from .operation import NpuBlockType
60from .operation import Op
61from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010062from .register_command_stream_generator import generate_command_stream
63from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010064from .register_command_stream_util import to_npu_kernel
65from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000066from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010067from .tensor import MemType
68from .tensor import Tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069from .tensor import TensorFormat
70from .tensor import TensorPurpose
Jonas Ohlsson845e2322022-03-01 12:39:55 +010071from .weight_compressor import NpuWeightTensor
Tim Halld8339a72021-05-27 18:49:40 +010072from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010073
74
class BasePointerIndex(IntEnum):
    """Base pointer indices; get_region() returns these values as region numbers."""

    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010079
80
# Maps a tensor DataType to the NpuDataType stored in a feature map (looked up in create_feature_map).
# Data types that are missing from this table cause a KeyError at lookup time.
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
88
89
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
# Note that the Rescale variants share the NPU operation of their plain counterpart (MUL/ADD).
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
105
106
# inverse of the resampling_mode_map in the register command stream generator
# (used by set_common_op_fields to translate op.ifm_resampling_mode into an NpuResamplingMode)
resampling_mode_inv_map = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}
113
114
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the two elementwise inputs are already in the order the NPU needs.

    A scalar (empty shape) or a broadcasted feature map must be the second input.

    :param ifm_shape: shape of the first input, [] for a scalar
    :param ifm2_shape: shape of the second input, [] for a scalar
    :return: False if the inputs need to be swapped, True otherwise
    """
    if ifm_shape == []:
        # A scalar first input always has to move to IFM2
        return False
    if ifm2_shape == []:
        return True

    # Dimensions are paired from the front; a size-1 dimension in the first input that differs
    # from the second input means the first input is the broadcasted one
    needs_swap = any(dim != dim2 and dim == 1 for dim, dim2 in zip(ifm_shape, ifm2_shape))
    return not needs_swap
127
128
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Selects the rounding mode for the given operation.

    :param op: operation to select the rounding mode for
    :param fused_quantize: True if a Quantize operation is fused with op
    :return: op.rounding_mode if it is set, otherwise NATURAL for the cases below and TFL for all others
    """
    use_natural = (
        op.type == Op.ResizeBilinear
        # 16-bit convolutions
        or (
            op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
            and op.ifm.dtype == DataType.int16
        )
        # 1x1 average pool that performs a concat slice write without a fused quantize
        or (
            not fused_quantize
            and op.type.is_avgpool_op()
            and op.memory_function == Op.ConcatSliceWrite
            and op.kernel.elements_wh() == 1
        )
    )
    default_mode = NpuRoundingMode.NATURAL if use_natural else NpuRoundingMode.TFL
    # An explicitly set rounding mode on the operation always wins
    explicit_mode = op.rounding_mode
    return default_mode if explicit_mode is None else explicit_mode
149
150
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Creates the padding to use for the stripe described by cmd.

    :param cmd: stripe command; supplies the stripe's pad_top/pad_bottom and its ifm box
    :param primary_op: operation whose "explicit_padding" attribute is the starting point
    :return: all-zero padding for vector products, otherwise the padding adjusted to the stripe
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)

    pad_top, pad_left, pad_bottom, pad_right = primary_op.attrs["explicit_padding"]

    # With horizontal ifm streaming the vertical padding is taken from the stripe itself
    if not cmd.is_first_h_stripe or not cmd.is_last_h_stripe:
        pad_top, pad_bottom = cmd.pad_top, cmd.pad_bottom

    # The ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    read_offset = primary_op.read_offsets[0]
    read_shape = primary_op.read_shapes[0]
    if read_offset is not None and len(read_offset) >= 2:
        width_coord_min = read_offset[-2]
        width_coord_max = read_shape[-2]
    else:
        width_coord_min = 0
        width_coord_max = cmd.ps.ifm_shapes[0].width

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because of activation function needed to be fused.
    box_start = cmd.ifm_box.start_coord
    box_end = cmd.ifm_box.end_coord
    if len(box_start) >= 2 and box_start[-2] > width_coord_min:
        # The box does not start at the left edge, so no left padding
        pad_left = 0
    if len(box_end) >= 2 and box_end[-2] < width_coord_max:
        # The box does not reach the right edge, so no right padding
        pad_right = 0
    return NpuPadding(top=pad_top, left=pad_left, bottom=pad_bottom, right=pad_right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100178
179
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the region (base pointer index value) for the given memory type.

    :param mem_type: memory type to look up
    :param arch: architecture; decides whether Scratch_fast has a region of its own
    :return: integer value of the matching BasePointerIndex
    :raises KeyError: if mem_type has no region
    """
    # Scratch_fast only gets its own base pointer when spilling is enabled
    if arch.is_spilling_enabled():
        scratch_fast_index = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_index = BasePointerIndex.ScratchTensor

    index_for_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return index_for_mem_type[mem_type].value
193
194
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes

    When several memory types share a region, the size of the last one in MemType.all() is used.
    The mem2mem region is limited by the SHRAM size.
    """
    limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100202
203
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the resampling mode implied by the type of the given operation."""
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
Louis Verhaarde91b5312022-01-21 13:38:50 +0100213
214
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth: the depth of the ifm box for convolutions, vector products and
    reduce sum, the depth of the ofm box for all other block types."""
    depth_from_ifm = npu_block_type in (
        NpuBlockType.ConvolutionMxN,
        NpuBlockType.VectorProduct,
        NpuBlockType.ReduceSum,
    )
    depth_box = ifm_box if depth_from_ifm else ofm_box
    return depth_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100221
222
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point

    :param ps: pass; its primary op and fused ops are inspected
    :param tens: tensor that the quantization belongs to
    :param is_ifm_tensor: True if tens is an input feature map
    :return: True if zero point 0 should be used instead of the tensor's own zero point
    """
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True

    primary_op = ps.primary_op
    if primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if primary_op.type == Op.AvgPool and primary_op.explicit_scaling:
        return False

    has_fused_quantize = any(fused_op.type == Op.Quantize for fused_op in ps.ops)
    activation = primary_op.activation
    # No activation, a forced output quantization, or an average pool with a relu and no rescale
    activation_permits = (
        activation is None
        or primary_op.forced_output_quantization is not None
        or (primary_op.type.is_avgpool_op() and activation.op_type.is_relu_op() and not primary_op.rescale)
    )
    if not activation_permits:
        return False
    if primary_op.memory_function == Op.ConcatSliceWrite:
        return False
    return not has_fused_quantize
247
248
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2

    The primary op's forced input quantization takes precedence over the tensor's quantization.
    Returns None if neither is available.
    """
    forced_quant = ps.primary_op.forced_input_quantization
    quant = tens.quantization if forced_quant is None else forced_quant
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100260
261
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM

    The primary op's forced output quantization (used in LUTs) takes precedence over the output
    tensor's quantization. Returns None if neither is available.
    """
    forced_quant = ps.primary_op.forced_output_quantization
    quant = tens.quantization if forced_quant is None else forced_quant
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
275
276
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated

    :param tens: tensor that the feature map describes
    :param box: box (start/end coordinates) of the tensor that is accessed
    :param arch: architecture; used to find the region of the tensor
    :param op_shape4D: 4D shape used for the address and stride calculations
    :return: feature map with region, data type, layout, tiles, strides and name set;
             shape and quantization are left for the caller to fill in
    :raises KeyError: if the tensor's data type is not in dtype_map
    :raises AssertionError: if the tensor format is neither NHWC nor NHCWB16
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        # Raised explicitly instead of "assert 0" so that the check is not stripped by python -O
        raise AssertionError("Incorrect tensor format")
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # None entries are replaced by address 0; a new list is built so that the list returned by
    # the tensor is not modified
    tile_addresses = [0 if addr is None else int(addr) for addr in addresses]
    fm.tiles = NpuTileBox(height_0=height_0, height_1=height_1, width_0=width_0, addresses=tile_addresses)
    strides = tens.get_strides(shape4D=op_shape4D)
    # strides[1] is used as the depth stride, strides[2] as height and strides[3] as width
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    fm.name = tens.name
    return fm
301
302
def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales

    :param weight_tensor: weight tensor; if it has a src_tensor, the encoded ranges are taken from that
    :param weight_box: box of the weights; its last start coordinate selects the encoded range
    :param scale_tensor: standalone scale tensor, or a falsy value if the scales are stored with the weights
    :param arch: architecture; supplies the number of cores and the regions
    :return: (weights, biases), each with one address range per core that has an encoded range
    """
    weight_ranges = []
    scale_ranges = []
    weight_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0

    if weight_tensor.src_tensor:
        source_tensor = cast(NpuWeightTensor, weight_tensor.src_tensor)
    else:
        source_tensor = weight_tensor

    buffered_offset = 0
    for core in range(arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key not in source_tensor.encoded_ranges:
            continue
        weight_range = source_tensor.encoded_ranges[key]

        if weight_tensor == source_tensor:
            # Straight from source tensor
            address = weight_tensor.address + weight_range.offset
        else:
            # Weight buffered tensor: the ranges of the cores are packed back to back, 16-byte aligned
            address = weight_tensor.address + buffered_offset
            buffered_offset += round_up(weight_range.total_bytes, 16)

        # Location of weights in tensor
        weight_ranges.append(
            NpuAddressRange(
                weight_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
        )

        # Location of standalone scales or combined weights tensor scales
        if scale_tensor:
            assert scale_tensor.src_tensor is None  # Must be standalone
            scale_range = scale_tensor.encoded_ranges[key]
            bias_region = scale_region
            bias_address = scale_tensor.address + scale_range.offset
            bias_bytes = scale_range.scale_bytes
        else:
            # The scales are located at the start of the combined range
            bias_region = weight_region
            bias_address = address
            bias_bytes = weight_range.scale_bytes
        scale_ranges.append(NpuAddressRange(bias_region, int(bias_address), round_up(int(bias_bytes), 16)))

    return weight_ranges, scale_ranges
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100348
349
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function

    :param op: operation whose (optional) activation is converted
    :return: NpuActivation with min/max and lookup table index taken from op.activation,
             or a plain NONE_OR_RELU activation if the operation has no activation
    :raises UnsupportedFeatureError: if the activation is not Tanh, Sigmoid, LUT or a relu op
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif faf.is_relu_op():
        act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = activation.min
    act.max = activation.max
    adjust_limits = act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale
    if adjust_limits:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            # Re-express the limits using the OFM quantization; a missing scale is treated as 1
            scale = quant.scale_f32 if quant.scale_f32 is not None else 1
            if act.min is not None:
                act.min = scale * quantise_float32(act.min, scale, quant.zero_point)
            if act.max is not None:
                act.max = scale * quantise_float32(act.max, scale, quant.zero_point)
    act.lookup_table_index = activation.lut_index
    return act
379
380
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation

    :param npu_op: operation to populate; it is modified in place
    :param cmd: stripe command that supplies tensors, boxes and the pass
    :param arch: architecture that the operation is created for
    :return: npu_op
    """
    ps = cmd.ps
    primary_op = ps.primary_op

    # IFM: the height comes from the stripe's ifm box, the width from the full ifm shape
    ifm_shape = NpuShape3D(
        height=cmd.ifm_box.get_block().height,
        width=cmd.ps.ifm_shapes[0].width,
        depth=get_ifm_depth(primary_op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box),
    )
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = ifm_shape
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: the shape is the block of the stripe's ofm box
    ofm_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        weights, biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
        npu_op.weights = weights
        npu_op.biases = biases

    npu_op.activation = create_npu_activation(primary_op)
    npu_op.fused_quantize = any(fused_op.type == Op.Quantize for fused_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)
    # ps.block_config is indexed as height, width, (unused here), depth
    block_height, block_width, block_depth = ps.block_config[0], ps.block_config[1], ps.block_config[3]
    npu_op.block_config = NpuShape3D(height=block_height, width=block_width, depth=block_depth)

    # Elementwise operations have neither padding nor kernel
    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[primary_op.ifm_resampling_mode]
    return npu_op
411
412
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation

    Vector products always use depth first block traversal; other convolutions use the traversal
    of the weight tensor's source tensor if it has one, otherwise of the weight tensor itself.
    """
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
        return npu_op

    traversal_tensor = cmd.weight_tensor.src_tensor or cmd.weight_tensor
    npu_op.block_traversal = traversal_tensor.hw_traversal
    return npu_op
425
426
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation

    A depthwise convolution only needs the common fields to be set.
    """
    depthwise_op = NpuConvDepthWiseOperation()
    set_common_op_fields(depthwise_op, cmd, arch)
    return depthwise_op
432
433
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation

    :param cmd: stripe command whose primary op is a pooling type operation
    :param arch: architecture that the operation is created for
    :return: pooling operation with common fields and rescale set
    :raises AssertionError: if the primary op is not a max pool, average pool, ResizeBilinear or ReduceSum
    """
    ps = cmd.ps
    op = ps.primary_op
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        # Raised explicitly instead of "assert 0": under python -O the assert is stripped and an
        # unknown type would otherwise silently be treated as an average pool
        raise AssertionError(f"Unknown pool type {op.type}")
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op
456
457
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation

    Note that for binary operations this function may modify cmd and cmd.ps in place: if the
    scalar/broadcasted input is the first input, the ifm/ifm2 tensors, boxes and shapes are swapped
    and the created operation is marked with reversed_operands.

    :param cmd: stripe command whose primary op type is a key of elementwise_op_map
    :param arch: architecture that the operation is created for
    :return: elementwise operation with ifm2 (binary ops only), common fields and output scale set
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    # Only binary operations have a second input
    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            # height and depth come from the stripe's ifm2 box, the width from the full ifm2 shape
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    # Must be called after a possible swap above, since it reads the ifm tensor, box and shape from cmd
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        # A fused Sigmoid/Tanh takes precedence over the resizebilinear override above
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        # Only the scale is overridden, the zero point of the OFM quantization is kept
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
505
506
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation

    :param cmd: DMA command that supplies the in/out tensors and the box to transfer
    :param arch: architecture; supplies the regions and the number of cores
    :return: DMA operation whose source and destination ranges have the same size
    """
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    # LUTs are transferred to the mem2mem region, everything else to the region of the out tensor
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        # The size is the sum of the 16-byte aligned encoded ranges of all cores
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key in cmd.in_tensor.encoded_ranges:
                weight_range = cmd.in_tensor.encoded_ranges[key]
                sz += round_up(weight_range.total_bytes, 16)

                # NOTE(review): src_addr/dest_addr are only assigned while handling core 0 -
                # confirm that core 0 always has an encoded range for the box
                if core == 0:
                    weight_range = cmd.in_tensor.encoded_ranges[key]
                    src_addr = cmd.in_tensor.address + weight_range.offset
                    dest_addr = cmd.out_tensor.address
    else:
        # Other tensors: transfer the address range that is covered by the box
        start_coord = cmd.box.start_coord
        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
536
537
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    :param cmd: high level command, either a DMA or an NpuStripe
    :param arch: architecture that the operation is created for
    :return: the created operation; it is named after the DMA's out tensor or the stripe's primary op
    :raises AssertionError: if the block type of the stripe's primary op is not supported
    :raises TypeError: if cmd is neither a DMA nor an NpuStripe
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
        npu_op.name = cmd.out_tensor.name
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            # Raised explicitly instead of "assert 0" so that the check is not stripped by python -O
            raise AssertionError(f"Unknown command type {npu_block_type}")
        npu_op.name = cmd.ps.primary_op.name
    else:
        # Previously this fell through to the return statement with npu_op unassigned
        raise TypeError(f"Unsupported high level command: {type(cmd).__name__}")
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100558
559
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream

    :param nng: not used by this function
    :param sg: subgraph whose high_level_command_stream is converted; if the stream is not empty,
               generated_stream_id and register_command_stream are set on it
    :param arch: architecture that the command stream is generated for
    :param verbose: passed on to the register command stream generator
    """
    # Convert high level command stream to list of NpuOperation
    npu_ops = []
    cmd_for_npu_op = {}  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        is_default_stripe = isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default
        if is_default_stripe:
            print("Warning: Skipping register command stream generation for", cmd.ps)
            continue
        npu_op = convert_command_to_npu_op(cmd, arch)
        npu_ops.append(npu_op)
        cmd_for_npu_op[npu_op] = cmd

    mem_limits = get_mem_limits_for_regions(arch)
    if len(sg.high_level_command_stream) == 0:
        # Nothing to generate for an empty high level command stream
        return

    # Generate register commands
    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        if isinstance(npu_op, NpuDmaOperation):
            # DMA operations are not recorded
            return
        stripe = cmd_for_npu_op[npu_op]
        DebugDatabase.add_command(stream_id, offset, stripe.ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_ops, arch, verbose, mem_limits, add_to_debug_db, cmd_for_npu_op
    )