blob: c822132031c82342368888a3a6e425fa22d3a20a [file] [log] [blame]
# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
19from enum import IntEnum
Jonas Ohlsson845e2322022-03-01 12:39:55 +010020from typing import cast
Louis Verhaard024c3552021-03-17 14:26:34 +010021from typing import Dict
Louis Verhaarde8a5a782020-11-02 18:04:27 +010022from typing import List
23from typing import Optional
Jonas Ohlsson845e2322022-03-01 12:39:55 +010024from typing import Tuple
Louis Verhaarde8a5a782020-11-02 18:04:27 +010025
26from .api import NpuActivation
27from .api import NpuActivationOp
28from .api import NpuAddressRange
29from .api import NpuBlockOperation
30from .api import NpuBlockTraversal
31from .api import NpuConv2DOperation
32from .api import NpuConvDepthWiseOperation
33from .api import NpuDataType
34from .api import NpuDmaOperation
35from .api import NpuElementWiseOp
36from .api import NpuElementWiseOperation
37from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010038from .api import NpuLayout
39from .api import NpuOperation
40from .api import NpuPadding
41from .api import NpuPoolingOp
42from .api import NpuPoolingOperation
43from .api import NpuQuantization
44from .api import NpuResamplingMode
45from .api import NpuRoundingMode
46from .api import NpuShape3D
47from .api import NpuTileBox
48from .architecture_features import ArchitectureFeatures
49from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010050from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000051from .errors import UnsupportedFeatureError
Tim Hall3c5cfe92022-03-16 16:31:57 +000052from .ethos_u55_regs.ethos_u55_regs import resampling_mode
Louis Verhaarde8a5a782020-11-02 18:04:27 +010053from .high_level_command_stream import Box
54from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010055from .high_level_command_stream import DMA
56from .high_level_command_stream import NpuStripe
Fredrik Svedberg838df0a2021-09-17 16:29:22 +020057from .numeric_util import quantise_float32
Tim Halld8339a72021-05-27 18:49:40 +010058from .numeric_util import round_up
Louis Verhaarde8a5a782020-11-02 18:04:27 +010059from .operation import NpuBlockType
60from .operation import Op
61from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010062from .register_command_stream_generator import generate_command_stream
63from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010064from .register_command_stream_util import to_npu_kernel
65from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000066from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010067from .tensor import MemType
68from .tensor import Tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069from .tensor import TensorFormat
70from .tensor import TensorPurpose
Tim Halld8339a72021-05-27 18:49:40 +010071from .tensor import TensorSubPurpose
Jonas Ohlsson845e2322022-03-01 12:39:55 +010072from .weight_compressor import NpuWeightTensor
Tim Halld8339a72021-05-27 18:49:40 +010073from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010074
75
class BasePointerIndex(IntEnum):
    """Base address indices used to select the memory region of a tensor."""

    # Base address index for the Weight tensor
    WeightTensor = 0
    # Base address index for the Scratch_tensor in the TensorArena
    ScratchTensor = 1
    # Base address index for the Scratch_fast_tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010080
81
# Maps a tensor DataType to the matching NpuDataType; only the types listed here can be
# used for a feature map (create_feature_map looks the tensor dtype up in this table)
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
89
90
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# The Rescale* variants share the mode of their plain counterpart.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
106
107
# Inverse of the resampling_mode_map in the register command stream generator:
# translates a register-level resampling_mode into the NpuResamplingMode value
resampling_mode_inv_map = {
    resampling_mode.NONE: NpuResamplingMode.NONE,
    resampling_mode.NEAREST: NpuResamplingMode.NEAREST,
    resampling_mode.TRANSPOSE: NpuResamplingMode.TRANSPOSE,
}
114
115
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the two elementwise inputs are already in the required operand order.

    Returns False when the operands must be swapped, i.e. when ifm is a scalar (empty shape)
    or when ifm is the broadcasted input. Returns True otherwise.
    """
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    # A dimension of size 1 in ifm that differs from ifm2 means ifm is broadcasted;
    # the broadcasted FM needs to be in IFM2
    return not any(dim == 1 and dim != dim2 for dim, dim2 in zip(ifm_shape, ifm2_shape))
128
129
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used

    The default is TFL rounding; NATURAL rounding is selected for ResizeBilinear, for int16
    (depthwise) convolutions and for a 1x1 average pool that writes into a concat slice without
    a fused quantize. A rounding mode set explicitly on the operation overrides all of this.
    """
    use_natural = (
        op.type == Op.ResizeBilinear
        or (
            op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
            and op.ifm.dtype == DataType.int16
        )
        or (
            not fused_quantize
            and op.type.is_avgpool_op()
            and op.memory_function == Op.ConcatSliceWrite
            and op.kernel.elements_wh() == 1
        )
    )
    default_mode = NpuRoundingMode.NATURAL if use_natural else NpuRoundingMode.TFL
    # An explicitly requested rounding mode always wins over the derived one
    return default_mode if op.rounding_mode is None else op.rounding_mode
150
151
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Creates the padding for the stripe described by cmd

    Starts from the explicit padding of the primary op and removes the padding on every side
    where the stripe's ifm box does not touch the edge of the area being read.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming; if so the stripe carries its own top/bottom padding
    if not cmd.is_first_h_stripe or not cmd.is_last_h_stripe:
        top, bottom = cmd.pad_top, cmd.pad_bottom

    # the ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    ifm_read_offset = primary_op.read_offsets[0]
    ifm_read_shape = primary_op.read_shapes[0]
    if ifm_read_offset is not None and len(ifm_read_offset) >= 2:
        box_start_coord_min = ifm_read_offset[-2]
        box_end_coord_max = ifm_read_shape[-2]
    else:
        box_start_coord_min = 0
        box_end_coord_max = cmd.ps.ifm_shapes[0].width

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because of activation function needed to be fused.
    start_coord = cmd.ifm_box.start_coord
    if len(start_coord) >= 2 and start_coord[-2] > box_start_coord_min:
        left = 0
    end_coord = cmd.ifm_box.end_coord
    if len(end_coord) >= 2 and end_coord[-2] < box_end_coord_max:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100179
180
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer index (region) that is used for the given memory type"""
    # Scratch_fast only gets its own region when spilling is enabled,
    # otherwise it shares the region of the scratch tensor
    if arch.is_spilling_enabled():
        scratch_fast_index = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_index = BasePointerIndex.ScratchTensor

    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return base_ptr_idx_map[mem_type].value
194
195
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    # Memory types that share a region overwrite each other; the last one in MemType.all() wins
    mem_limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    # The mem2mem region is limited by the size of the SHRAM
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100203
204
def get_double_buffer_offset(arch: ArchitectureFeatures, range_index: int, core: int) -> int:
    """Returns 0 if the first half of a double buffer should be used, 1 if the second half should be used"""
    # Consecutive ranges are spread over the cores, so every arch.ncores ranges the buffer half alternates
    buffer_round = (range_index - core) // arch.ncores
    return buffer_round % 2
208
209
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the ifm depth: taken from the ifm box for MxN convolution, vector product and
    reduce sum, and from the ofm box for every other block type"""
    depth_from_ifm = npu_block_type in (
        NpuBlockType.ConvolutionMxN,
        NpuBlockType.VectorProduct,
        NpuBlockType.ReduceSum,
    )
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100216
217
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True

    primary_op = ps.primary_op
    if primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if primary_op.type == Op.AvgPool and primary_op.explicit_scaling:
        return False

    # A concat slice write or a fused quantize always needs the real zero point
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    if primary_op.memory_function == Op.ConcatSliceWrite or fused_quantize:
        return False

    if primary_op.activation is None or primary_op.forced_output_quantization is not None:
        return True
    # Remaining case: average pool fused with a relu and without rescale
    return (
        primary_op.type.is_avgpool_op()
        and primary_op.activation.op_type.is_relu_op()
        and not primary_op.rescale
    )
242
243
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2

    A forced input quantization on the primary op takes precedence over the quantization of
    the tensor. Returns None if neither is available.
    """
    forced_quant = ps.primary_op.forced_input_quantization
    ifm_quant = tens.quantization if forced_quant is None else forced_quant
    if ifm_quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100255
256
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM

    Returns None if neither the primary op nor the tensor provides a quantization.
    """
    # Check if the operation's output quantization should be used instead of the output tensor's
    # quantization (used in LUTs)
    forced_quant = ps.primary_op.forced_output_quantization
    ofm_quant = tens.quantization if forced_quant is None else forced_quant
    if ofm_quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
270
271
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated

    Fills in region, data type, layout, tiles and strides; shape and quantization are left
    for the caller to set.
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]

    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Unused tiles have no address; replace them (in place) by address 0
    for tile_idx in range(len(addresses)):
        if addresses[tile_idx] is None:
            addresses[tile_idx] = 0
    tile_addresses = [int(addr) for addr in addresses]
    fm.tiles = NpuTileBox(height_0=height_0, height_1=height_1, width_0=width_0, addresses=tile_addresses)

    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm
295
296
def create_weights(
    weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales

    One weight range and one scale range is produced per core that has an encoded range for
    the depth at which weight_box starts. Scales are taken from scale_tensor when one is given,
    otherwise from the scales stored together with the weights.
    """
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0

    # The encoded ranges live in the source tensor when the weights have been copied to a buffer
    if weight_tensor.src_tensor:
        w_tensor_src = cast(NpuWeightTensor, weight_tensor.src_tensor)
    else:
        w_tensor_src = weight_tensor

    core_offset = 0
    for core in range(arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key not in w_tensor_src.encoded_ranges:
            continue
        weight_range = w_tensor_src.encoded_ranges[key]

        if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
            assert weight_tensor != w_tensor_src
            # Double buffered inside weight_tensor
            buffer_half = get_double_buffer_offset(arch, weight_range.index, core)
            address = weight_tensor.address + core_offset + buffer_half * w_tensor_src.max_range_bytes
            core_offset += round_up(weight_range.total_bytes, 16)
        elif weight_tensor == w_tensor_src:
            # Straight from source tensor
            address = weight_tensor.address + weight_range.offset
        else:
            # Single buffered inside weight tensor
            address = weight_tensor.address + core_offset
            core_offset += round_up(weight_range.total_bytes, 16)

        # Location of weights in tensor
        weights.append(
            NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
        )

        # Location of standalone scales or combined weights tensor scales
        if scale_tensor:
            assert scale_tensor.src_tensor is None  # Must be standalone
            scale_range = scale_tensor.encoded_ranges[key]
            scale_address = scale_tensor.address + scale_range.offset
            biases.append(
                NpuAddressRange(scale_region, int(scale_address), round_up(int(scale_range.scale_bytes), 16))
            )
        else:
            # Combined tensor: the scales are addressed from the start of the range
            biases.append(NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16)))

    return weights, biases
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100349
350
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function

    Raises UnsupportedFeatureError if the fused activation is not Tanh, Sigmoid, LUT or a relu.
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    non_relu_ops = {
        Op.Tanh: NpuActivationOp.TANH,
        Op.Sigmoid: NpuActivationOp.SIGMOID,
        Op.LUT: NpuActivationOp.TABLE_LOOKUP,
    }
    act_op = non_relu_ops.get(faf)
    if act_op is None:
        if not faf.is_relu_op():
            raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")
        act_op = NpuActivationOp.NONE_OR_RELU

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            # Re-express the clamp limits through quantise_float32 using the ofm scale and zero point
            scale_f32 = quant.scale_f32 if quant.scale_f32 is not None else 1
            zero_point = quant.zero_point
            if act.min is not None:
                act.min = scale_f32 * quantise_float32(act.min, scale_f32, zero_point)
            if act.max is not None:
                act.max = scale_f32 * quantise_float32(act.max, scale_f32, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act
380
381
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation

    Populates ifm, ofm, weights/biases (if the command has weights), activation, rounding mode,
    block config, padding and kernel (non-elementwise ops only) and ifm upscaling.
    Returns the same npu_op that was passed in.
    """
    ps = cmd.ps
    op = ps.primary_op

    # IFM: height comes from the stripe's box, width from the full ifm shape
    ifm_shape = NpuShape3D(
        height=cmd.ifm_box.get_block().height,
        width=cmd.ps.ifm_shapes[0].width,
        depth=get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box),
    )
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = ifm_shape
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: all dimensions come from the stripe's box
    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(op)
    fused_quantize = any(ps_op.type == Op.Quantize for ps_op in ps.ops)
    npu_op.fused_quantize = fused_quantize
    npu_op.rounding_mode = get_rounding_mode(op, fused_quantize)
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    # Padding and kernel are only set for non-elementwise operations
    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = resampling_mode_inv_map[op.ifm_resampling_mode]
    return npu_op
412
413
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)

    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
        return npu_op

    # The traversal is a property of the source weight tensor when the weights were copied to a buffer
    weight_tensor = cmd.weight_tensor
    traversal_owner = weight_tensor.src_tensor if weight_tensor.src_tensor else weight_tensor
    npu_op.block_traversal = traversal_owner.hw_traversal
    return npu_op
426
427
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    # A depthwise convolution only needs the common fields; set_common_op_fields hands back its argument
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
433
434
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    op = cmd.ps.primary_op
    op_type = op.type

    pool_op = NpuPoolingOp.AVERAGE
    if op_type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        # ResizeBilinear is executed as an average pool
        pool_op = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"

    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)

    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op
457
458
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation

    Note that for binary operations the ifm/ifm2 tensors, boxes and shapes of cmd (and its pass)
    are swapped in place when the scalar/broadcasted input is not already the ifm2.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    is_binary = elemwise_op not in UNARY_ELEMWISE_OPS
    if is_binary:
        ifm_shape = ps.ifm_shapes[0].as_list() if cmd.ifm_tensor.shape != [] else []
        ifm2_shape = ps.ifm_shapes[1].as_list() if cmd.ifm2_tensor.shape != [] else []
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True

        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            npu_op.ifm2.shape = NpuShape3D(
                height=ifm2_blk.height, width=ps.ifm_shapes[1].width, depth=ifm2_blk.depth
            )

    set_common_op_fields(npu_op, cmd, arch)

    # Check if output scale needs to be overridden; later checks take precedence over earlier ones
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if (
        op.type in (Op.Add, Op.Mul, Op.Sub)
        and op.activation is not None
        and op.activation.op_type in (Op.Sigmoid, Op.Tanh)
    ):
        output_scale = 1 / 0x3000

    if output_scale is not None:
        # Only the scale is overridden; the zero point of the ofm is kept
        ofm_zero_point = npu_op.ofm.quantization.zero_point
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=ofm_zero_point)
    return npu_op
506
507
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    in_tensor = cmd.in_tensor
    out_tensor = cmd.out_tensor

    src_region = get_region(in_tensor.mem_type, arch)
    if out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(out_tensor.mem_type, arch)

    if in_tensor.purpose != TensorPurpose.Weights:
        start_coord = cmd.box.start_coord
        src_addr = in_tensor.address_for_coordinate(start_coord)
        dest_addr = out_tensor.address_for_coordinate(start_coord)
        sz = in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    else:
        # Get weight range per core; the transfer covers the ranges of all cores
        sz = 0
        for core in range(arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key not in in_tensor.encoded_ranges:
                continue
            weight_range = in_tensor.encoded_ranges[key]
            sz += round_up(weight_range.total_bytes, 16)

            if core != 0:
                continue
            # Source and destination addresses are determined by the range of the first core
            src_addr = in_tensor.address + weight_range.offset
            if out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                buffer_half = get_double_buffer_offset(arch, weight_range.index, core)
                dest_addr = out_tensor.address + in_tensor.max_range_bytes * buffer_half
            else:
                dest_addr = out_tensor.address

    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
543
544
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    DMA commands become a NpuDmaOperation; NpuStripe commands are dispatched on the
    npu block type of the primary op of their pass.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        block_type = cmd.ps.primary_op.type.npu_block_type
        if block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {block_type}"
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100563
564
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream

    Nothing is generated (and sg is left untouched) when the high level command stream is empty.
    """
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = {}  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        skip = isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default
        if skip:
            print("Warning: Skipping register command stream generation for", cmd.ps)
            continue
        npu_op = convert_command_to_npu_op(cmd, arch)
        npu_op_list.append(npu_op)
        npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)

    # Generate register commands
    if len(sg.high_level_command_stream) == 0:
        return

    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        # DMA operations have no primary op to register
        if isinstance(npu_op, NpuDmaOperation):
            return
        source_cmd = npu_op_to_cmd[npu_op]
        DebugDatabase.add_command(stream_id, offset, source_cmd.ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
    )