blob: c5d064651b8749a02754bfaad8493d9c5228a69b [file] [log] [blame]
# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069
70
class BasePointerIndex(IntEnum):
    """Base address indices; get_region() returns the value of one of these as the region number."""

    # Base address index for the weight tensor
    WeightTensor = 0
    # Base address index for the scratch tensor in the tensor arena
    ScratchTensor = 1
    # Base address index for the scratch-fast tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010075
76
# Maps a tensor DataType to the NpuDataType used in the external API.
# Only the data types listed here can be placed in a feature map; create_feature_map()
# indexes this table directly, so any other dtype raises KeyError there.
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
84
85
# Maps an elementwise op type to the elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# Note that Op.RescaleAdd shares the ADD mode with Op.Add.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
100
101
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the two elementwise inputs are already in the order the NPU needs.

    Returns False when the operands must be swapped, i.e. when the first input is a
    scalar (empty shape) or is the broadcasted one (has a dimension of size 1 where
    the second input differs). Returns True otherwise.
    """
    # A scalar always has to end up as IFM2
    if ifm_shape == []:
        return False
    if ifm2_shape == []:
        return True

    # The broadcasted feature map has to be IFM2 as well
    ifm_is_broadcasted = any(dim == 1 and dim != dim2 for dim, dim2 in zip(ifm_shape, ifm2_shape))
    return not ifm_is_broadcasted
114
115
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used.

    An explicitly set op.rounding_mode always takes precedence. Otherwise NATURAL
    rounding is selected for ResizeBilinear, for MxN/depthwise convolutions with an
    int16 IFM, and for average pools with a single-element kernel that act as
    ConcatSliceWrite without a fused quantize. Everything else uses TFL rounding.
    """
    use_natural = (
        op.type == Op.ResizeBilinear
        or (
            op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
            and op.ifm.dtype == DataType.int16
        )
        or (
            not fused_quantize
            and op.type.is_avgpool_op()
            and op.memory_function == Op.ConcatSliceWrite
            and op.kernel.elements_wh() == 1
        )
    )
    default_mode = NpuRoundingMode.NATURAL if use_natural else NpuRoundingMode.TFL
    return default_mode if op.rounding_mode is None else op.rounding_mode
136
137
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Creates the padding to use for the stripe described by cmd.

    VectorProduct operations never use padding. For other operations the padding
    starts out as the op's "explicit_padding" attribute and is then adjusted for
    the part of the IFM that this stripe actually covers.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)

    pad_top, pad_left, pad_bottom, pad_right = primary_op.attrs["explicit_padding"]

    # Horizontal ifm streaming: the stripe is not the whole height, use the stripe's own padding
    if not cmd.is_first_h_stripe or not cmd.is_last_h_stripe:
        pad_top = cmd.pad_top
        pad_bottom = cmd.pad_bottom

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because of activation function needed to be fused.
    if not primary_op.attrs.get("force_padding"):
        box_start = cmd.ifm_box.start_coord
        box_end = cmd.ifm_box.end_coord
        # No left padding if the box does not start at the left edge
        if len(box_start) >= 2 and box_start[-2] > 0:
            pad_left = 0
        # No right padding if the box ends before the right edge
        if len(box_end) >= 2 and box_end[-2] < cmd.ps.ifm_shapes[0].width:
            pad_right = 0
    return NpuPadding(top=pad_top, left=pad_left, bottom=pad_bottom, right=pad_right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100156
157
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the region (base pointer index value) used for the given memory type.

    Scratch_fast only gets its own region when spilling is enabled; otherwise it
    shares the region of the regular scratch tensor. Raises KeyError for a memory
    type that is not in the table.
    """
    region_for_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }
    region_for_mem_type[MemType.Scratch_fast] = (
        BasePointerIndex.ScratchFastTensor if arch.is_spilling_enabled() else BasePointerIndex.ScratchTensor
    )
    return region_for_mem_type[mem_type].value
171
172
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    # Several memory types can share a region; the last one iterated determines the limit
    limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    # The mem2mem region is limited by the SHRAM size
    limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100180
181
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling (upscale) mode required by the operation's type"""
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
191
192
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to use for the given block type.

    MxN convolution, vector product and reduce sum take the depth from the IFM box;
    all other block types take it from the OFM box.
    """
    depth_from_ifm = npu_block_type in (
        NpuBlockType.ConvolutionMxN,
        NpuBlockType.VectorProduct,
        NpuBlockType.ReduceSum,
    )
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100199
200
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    # int32 inputs always use zero point 0
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True

    primary = ps.primary_op
    # Only these op types are candidates for a forced zero point of 0
    if primary.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if primary.type == Op.AvgPool and primary.explicit_scaling:
        return False

    has_fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quant = primary.forced_output_quantization
    no_activation_or_forced_quant = primary.activation is None or forced_ofm_quant is not None
    return (
        no_activation_or_forced_quant
        and primary.memory_function != Op.ConcatSliceWrite
        and not has_fused_quantize
    )
217
218
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2.

    The primary op's forced input quantization is used when set, otherwise the
    tensor's own quantization. Returns None if neither is available.
    """
    quant = ps.primary_op.forced_input_quantization
    if quant is None:
        quant = tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100230
231
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM.

    Returns None if neither a forced output quantization nor a tensor quantization
    is available.
    """
    # Check if the operation's output quantization should be used instead of the
    # output tensor's quantization (used in LUTs)
    quant = ps.primary_op.forced_output_quantization
    if quant is None:
        quant = tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
245
246
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated.

    Populates region, data type, layout, tiles and strides. Shape and quantization
    are not set here.
    """
    feature_map = NpuFeatureMap()
    feature_map.region = get_region(tens.mem_type, arch)
    feature_map.data_type = dtype_map[tens.dtype]

    if tens.format == TensorFormat.NHWC:
        feature_map.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        feature_map.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Missing tile addresses are replaced by 0 (in place, in the returned list)
    for tile_idx, tile_addr in enumerate(addresses):
        if tile_addr is None:
            addresses[tile_idx] = 0
    tile_addresses = [int(tile_addr) for tile_addr in addresses]
    feature_map.tiles = NpuTileBox(height_0=height_0, height_1=height_1, width_0=width_0, addresses=tile_addresses)

    # get_strides() result is indexed so that [2] is the height, [3] the width and [1] the depth stride
    tens_strides = tens.get_strides(shape4D=op_shape4D)
    feature_map.strides = NpuShape3D(
        height=int(tens_strides[2]), width=int(tens_strides[3]), depth=int(tens_strides[1])
    )
    return feature_map
270
271
def create_weights(
    weight_tensor: Tensor, weight_box: Box, scale_tensor: Tensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales.

    The result is a (weights, biases) pair of lists. For every core that has an
    encoded range for the depth at which weight_box starts, one entry is appended
    to each list. Cores without an encoded range are skipped.

    scale_tensor may be falsy (e.g. None), in which case the scales are taken from
    the combined weight tensor; otherwise it must be a standalone tensor (no
    src_tensor).
    """
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = scale_tensor and get_region(scale_tensor.mem_type, arch)

    # The encoded ranges live on the source tensor when the weights are buffered
    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = weight_tensor.src_tensor

    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                assert weight_tensor != w_tensor_src
                # Double buffered inside weight_tensor
                address = weight_tensor.address + w_tensor_src.max_range_bytes * ((weight_range.index - core) % 2)
                address += core_offset
                core_offset += round_up(weight_range.total_bytes, 16)
            elif weight_tensor == w_tensor_src:
                # Straight from source tensor
                address = weight_tensor.address + weight_range.offset
            else:
                # Single buffered inside weight tensor
                address = weight_tensor.address + core_offset
                core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                # Scales are located at the start of this core's range in the weight tensor
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100324
325
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function.

    Raises UnsupportedFeatureError if the op's activation is not Tanh, Sigmoid,
    LUT or a ReLU variant.
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    if faf == Op.Tanh:
        npu_act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        npu_act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        npu_act_op = NpuActivationOp.TABLE_LOOKUP
    elif faf.is_relu_op():
        npu_act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    npu_activation = NpuActivation(npu_act_op)
    npu_activation.min = op.activation.min
    npu_activation.max = op.activation.max
    npu_activation.lookup_table_index = op.activation.lut_index
    return npu_activation
346
347
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation.

    Fills in IFM, OFM, weights/biases (when the command has a weight tensor),
    activation, fused quantize flag, rounding mode and block config. Padding and
    kernel are only set for non-elementwise operations. Returns npu_op.
    """
    ps = cmd.ps
    primary_op = ps.primary_op

    # IFM: height from the stripe's box, width from the full IFM shape
    ifm_shape = NpuShape3D(
        height=cmd.ifm_box.get_block().height,
        width=cmd.ps.ifm_shapes[0].width,
        depth=get_ifm_depth(primary_op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box),
    )
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = ifm_shape
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: shape is the block covered by the stripe's OFM box
    ofm_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)

    npu_op.activation = create_npu_activation(primary_op)
    npu_op.fused_quantize = any(fused_op.type == Op.Quantize for fused_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)

    # Block config depth is taken from index 3 of ps.block_config
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = get_upscale(primary_op)
    return npu_op
378
379
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation.

    Vector products always use depth first block traversal; otherwise the traversal
    recorded on the weight tensor's source tensor (or on the weight tensor itself
    if it has no source tensor) is used.
    """
    conv_op = NpuConv2DOperation()
    set_common_op_fields(conv_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        conv_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
        return conv_op

    src_tensor = cmd.weight_tensor.src_tensor
    traversal_source = src_tensor if src_tensor else cmd.weight_tensor
    conv_op.block_traversal = traversal_source.hw_traversal
    return conv_op
392
393
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    # Depthwise convolutions need nothing beyond the common fields
    depthwise_op = NpuConvDepthWiseOperation()
    set_common_op_fields(depthwise_op, cmd, arch)
    return depthwise_op
399
400
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation.

    ResizeBilinear is handled as an average pool.
    """
    primary_op = cmd.ps.primary_op
    op_type = primary_op.type

    # AVERAGE is also the value that is kept if the assert below is compiled out
    pooling_mode = NpuPoolingOp.AVERAGE
    if op_type.is_maxpool_op():
        pooling_mode = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pooling_mode = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pooling_mode = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {primary_op.type}"

    pool_npu_op = NpuPoolingOperation(pooling_mode)
    set_common_op_fields(pool_npu_op, cmd, arch)

    # Pooling specific info
    pool_npu_op.rescale = primary_op.rescale
    if primary_op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert pool_npu_op.rescale is None
        pool_npu_op.rescale = primary_op.explicit_scaling
    return pool_npu_op
423
424
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation.

    For binary operations the inputs may be swapped so that the scalar/broadcasted
    input becomes IFM2. Note that this swap modifies cmd (tensors and boxes) and
    ps.ifm_shapes in place.
    """
    ps = cmd.ps
    primary_op = ps.primary_op
    assert primary_op.type in elementwise_op_map, f"Unknown elementwise type {primary_op.type}"
    elemwise_mode = elementwise_op_map[primary_op.type]
    ew_op = NpuElementWiseOperation(elemwise_mode)

    is_binary = elemwise_mode not in UNARY_ELEMWISE_OPS
    if is_binary:
        shape_ifm = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        shape_ifm2 = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(shape_ifm, shape_ifm2):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            ew_op.reversed_operands = True

        ew_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        ew_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            ew_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            ew_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_block = cmd.ifm2_box.get_block()
            ew_op.ifm2.shape = NpuShape3D(
                height=ifm2_block.height, width=ps.ifm_shapes[1].width, depth=ifm2_block.depth
            )

    # Must come after the possible operand swap above, since it reads cmd.ifm_tensor/ifm_box
    set_common_op_fields(ew_op, cmd, arch)

    # Check if output scale needs to be overridden
    scale_override = None
    if primary_op.type == Op.Add and "resizebilinear" in primary_op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        scale_override = ew_op.ifm2.quantization.scale_f32
    elif primary_op.type == Op.Abs:
        scale_override = ew_op.ifm.quantization.scale_f32 / ew_op.ofm.quantization.scale_f32
    elif primary_op.type == Op.LeakyRelu:
        scale_override = primary_op.attrs["alpha"]
    elif primary_op.type == Op.RescaleAdd:
        assert primary_op.rescale is not None, f"{primary_op.type} must have rescale"
        ew_op.rescale = primary_op.rescale

    # A fused Sigmoid/Tanh takes precedence over any override chosen above
    if (
        primary_op.type in (Op.Add, Op.Mul, Op.Sub)
        and primary_op.activation is not None
        and primary_op.activation.op_type in (Op.Sigmoid, Op.Tanh)
    ):
        scale_override = 1 / 0x3000

    if scale_override is not None:
        ew_op.ofm.quantization = NpuQuantization(
            scale_f32=scale_override, zero_point=ew_op.ofm.quantization.zero_point
        )
    return ew_op
472
473
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation.

    For weight tensors the transfer size is the sum of the encoded ranges of all
    cores (each rounded up to 16 bytes) for the depth at which cmd.box starts, and
    the source/destination addresses are derived from core 0's range. For all other
    tensors the addresses and size follow from the box coordinates.
    """
    in_tensor = cmd.in_tensor
    out_tensor = cmd.out_tensor

    src_region = get_region(in_tensor.mem_type, arch)
    # LUTs are transferred to the mem2mem region
    dest_region = (
        BASE_PTR_INDEX_MEM2MEM if out_tensor.purpose == TensorPurpose.LUT else get_region(out_tensor.mem_type, arch)
    )

    if in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        sz = 0
        for core in range(arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key not in in_tensor.encoded_ranges:
                continue
            weight_range = in_tensor.encoded_ranges[key]
            sz += round_up(weight_range.total_bytes, 16)

            if core != 0:
                continue
            # The transfer starts at the range belonging to core 0
            src_addr = in_tensor.address + weight_range.offset
            if out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                # Alternate between the two halves of the double buffer
                dest_addr = out_tensor.address + in_tensor.max_range_bytes * ((weight_range.index - core) % 2)
            else:
                dest_addr = out_tensor.address
    else:
        box_start = cmd.box.start_coord
        src_addr = in_tensor.address_for_coordinate(box_start)
        dest_addr = out_tensor.address_for_coordinate(box_start)
        sz = in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr

    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
509
510
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation.

    DMA commands become NpuDmaOperation; NpuStripe commands are dispatched on the
    npu block type of their primary op.

    Raises TypeError if cmd is neither a DMA nor an NpuStripe command.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    else:
        # Without this branch an unexpected command would reach the return statement
        # with npu_op unbound and fail with an unhelpful UnboundLocalError
        raise TypeError(f"Cannot convert command of type {type(cmd).__name__} to an NpuOperation")
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100529
530
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream.

    Nothing is generated (and sg is left untouched) when the subgraph's high level
    command stream is empty.
    """
    # Convert high level command stream to list of NpuOperation
    npu_ops = []
    cmd_for_npu_op = {}  # map from npu op to high level command
    for hl_cmd in sg.high_level_command_stream:
        if isinstance(hl_cmd, NpuStripe) and hl_cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", hl_cmd.ps)
            continue
        converted_op = convert_command_to_npu_op(hl_cmd, arch)
        npu_ops.append(converted_op)
        cmd_for_npu_op[converted_op] = hl_cmd
    region_limits = get_mem_limits_for_regions(arch)

    # Generate register commands
    if len(sg.high_level_command_stream) == 0:
        return

    debug_stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = debug_stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        # DMA operations are not recorded
        if isinstance(npu_op, NpuDmaOperation):
            return
        source_cmd = cmd_for_npu_op[npu_op]
        DebugDatabase.add_command(debug_stream_id, offset, source_cmd.ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_ops, arch, verbose, region_limits, add_to_debug_db, cmd_for_npu_op
    )