blob: f8c9de36c69db46957c9326552ab637a319de575 [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069
70
class BasePointerIndex(IntEnum):
    """Base pointer (region) indices used to address tensors in the command stream."""

    # Region holding the weight tensor
    WeightTensor = 0
    # Region holding the scratch tensor inside the tensor arena
    ScratchTensor = 1
    # Region holding the fast scratch tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010075
76
# Maps a Vela tensor DataType to the corresponding NpuDataType of the public API
dtype_map = {
    # unsigned types
    DataType.uint8: NpuDataType.UINT8,
    # signed 8-bit
    DataType.int8: NpuDataType.INT8,
    # 16-bit types
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    # 32-bit
    DataType.int32: NpuDataType.INT32,
}
84
85
Louis Verhaarde8a5a782020-11-02 18:04:27 +010086# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
87elementwise_op_map = {
88 Op.Mul: NpuElementWiseOp.MUL,
89 Op.Add: NpuElementWiseOp.ADD,
Fredrik Svedberge82be7c2021-01-18 15:21:03 +010090 Op.RescaleAdd: NpuElementWiseOp.ADD,
Louis Verhaarde8a5a782020-11-02 18:04:27 +010091 Op.Sub: NpuElementWiseOp.SUB,
92 Op.Minimum: NpuElementWiseOp.MIN,
93 Op.Maximum: NpuElementWiseOp.MAX,
94 Op.LeakyRelu: NpuElementWiseOp.LRELU,
95 Op.Abs: NpuElementWiseOp.ABS,
96 Op.CLZ: NpuElementWiseOp.CLZ,
97 Op.SHR: NpuElementWiseOp.SHR,
98 Op.SHL: NpuElementWiseOp.SHL,
99}
100
101
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether IFM and IFM2 are already in the order the NPU requires.

    A scalar (empty shape) or a broadcasted feature map must be placed in IFM2,
    so the order is wrong when the IFM is a scalar, or when any IFM dimension
    is 1 while the matching IFM2 dimension differs.

    :param ifm_shape: shape of the first operand ([] for a scalar)
    :param ifm2_shape: shape of the second operand ([] for a scalar)
    :return: True if no swap is needed, False otherwise
    """
    if ifm_shape == []:
        return False
    if ifm2_shape == []:
        return True
    ifm_is_broadcast = any(a == 1 and a != b for a, b in zip(ifm_shape, ifm2_shape))
    return not ifm_is_broadcast
114
115
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Selects the rounding mode for the given operation.

    NATURAL rounding is used for any of:
    - ResizeBilinear,
    - int16 ConvolutionMxN / ConvolutionDepthWise,
    - a 1x1 average pool writing into a ConcatSliceWrite without a fused Quantize.
    Everything else uses TFL rounding. An explicit ``op.rounding_mode`` overrides the choice.
    """
    use_natural = (
        op.type == Op.ResizeBilinear
        or (
            op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
            and op.ifm.dtype == DataType.int16
        )
        or (
            not fused_quantize
            and op.type.is_avgpool_op()
            and op.memory_function == Op.ConcatSliceWrite
            and op.kernel.elements_wh() == 1
        )
    )
    selected = NpuRoundingMode.NATURAL if use_natural else NpuRoundingMode.TFL
    return selected if op.rounding_mode is None else op.rounding_mode
136
137
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Computes the padding to apply for this stripe.

    Vector products are never padded. Otherwise the op's ``explicit_padding``
    (top, left, bottom, right) is the starting point:
    - when the IFM is split into several horizontal stripes, top/bottom come from the stripe;
    - unless ``force_padding`` is set, left/right are dropped when the IFM box does not
      start at / reach the corresponding horizontal edge of the IFM.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Horizontal IFM streaming: the stripe knows its own vertical padding
    single_h_stripe = cmd.is_first_h_stripe and cmd.is_last_h_stripe
    if not single_h_stripe:
        top, bottom = cmd.pad_top, cmd.pad_bottom

    if not primary_op.attrs.get("force_padding"):
        # Coordinates are indexed from the end because a 1x1 AvgPool may have been added
        # with non 4-dimensional input/output (to fuse an activation function).
        start, end = cmd.ifm_box.start_coord, cmd.ifm_box.end_coord
        if len(start) >= 2 and start[-2] > 0:
            left = 0
        if len(end) >= 2 and end[-2] < cmd.ps.ifm_shapes[0].width:
            right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100156
157
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer (region) index used for tensors of the given memory type.

    Permanent memory (NPU or CPU) lives in the weight region and Scratch in the scratch
    region. Scratch_fast gets its own region only when spilling is enabled; otherwise it
    shares the scratch region. An unknown memory type raises KeyError.
    """
    if arch.is_spilling_enabled():
        fast_scratch_index = BasePointerIndex.ScratchFastTensor
    else:
        fast_scratch_index = BasePointerIndex.ScratchTensor

    region_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: fast_scratch_index,
    }
    return region_by_mem_type[mem_type].value
171
172
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns a map from region index to the maximum size of that region in bytes.

    Memory types that share a region overwrite each other in ``MemType.all()`` order.
    The memory-to-memory region is limited to the SHRAM size.
    """
    limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100180
181
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode for the operation.

    ResizeBilinear uses nearest-neighbour upscaling and Conv2DBackpropInputSwitchedBias
    uses insert-zero (transpose) upscaling; all other ops use no upscaling.
    """
    upscale_by_op_type = {
        Op.ResizeBilinear: NpuResamplingMode.NEAREST,
        Op.Conv2DBackpropInputSwitchedBias: NpuResamplingMode.TRANSPOSE,
    }
    return upscale_by_op_type.get(op.type, NpuResamplingMode.NONE)
191
192
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to program for this stripe.

    ConvolutionMxN, VectorProduct and ReduceSum take the depth of the IFM box;
    all other block types take the depth of the OFM box.
    """
    depth_from_ifm = npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
    depth_box = ifm_box if depth_from_ifm else ofm_box
    return depth_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100199
200
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point instead of the tensor's own.

    - Always for an int32 input tensor.
    - For AvgPool, ResizeBilinear, CLZ and SHL primary ops, only when all of these hold:
      there is no fused activation (or the output quantization is forced), the op does not
      write into a ConcatSliceWrite, and no Quantize op is fused into the pass.
    - Never otherwise.
    """
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    primary = ps.primary_op
    if primary.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    has_fused_quantize = any(fused.type == Op.Quantize for fused in ps.ops)
    if primary.memory_function == Op.ConcatSliceWrite or has_fused_quantize:
        return False
    return primary.activation is None or primary.forced_output_quantization is not None
215
216
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2.

    The primary op's forced input quantization takes precedence over the tensor's own.
    Returns None when neither is available. The zero point is forced to 0 when
    ``use_zero_point_0`` says so.
    """
    forced = ps.primary_op.forced_input_quantization
    quant = tens.quantization if forced is None else forced
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100228
229
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM.

    The primary op's forced output quantization (used e.g. for LUTs) takes precedence
    over the tensor's own. Returns None when neither is available. The zero point is
    forced to 0 when ``use_zero_point_0`` says so.
    """
    forced = ps.primary_op.forced_output_quantization
    quant = tens.quantization if forced is None else forced
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
243
244
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates a feature map with the fields shared by IFM/IFM2/OFM populated.

    Sets region, data type, layout, tile box (from the tensor's rolling-buffer addresses)
    and strides. Shape and quantization are left for the caller to fill in.
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]

    layout = {TensorFormat.NHWC: NpuLayout.NHWC, TensorFormat.NHCWB16: NpuLayout.NHCWB16}.get(tens.format)
    if layout is None:
        assert 0, "Incorrect tensor format"
    else:
        fm.layout = layout

    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Tiles without an address (None) are encoded as address 0 (normalised in place)
    for tile_idx, tile_addr in enumerate(addresses):
        if tile_addr is None:
            addresses[tile_idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=list(map(int, addresses))
    )

    # Strides are indexed as [.., depth, height, width]
    stride = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(depth=int(stride[1]), height=int(stride[2]), width=int(stride[3]))
    return fm
268
269
def create_weights(
    weight_tensor: Tensor, weight_box: Box, scale_tensor: Optional[Tensor], arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales.

    For every NPU core that has an encoded range for the depth slice starting at
    ``weight_box.start_coord[-1]``, one weight range and one scale ("bias") range is
    produced. Weight and scale sizes are rounded up to 16 bytes.

    :param weight_tensor: tensor the weights are read from. When it has a ``src_tensor``,
        the encoded ranges are looked up on that source tensor and ``weight_tensor`` is
        treated as a (single or double) buffer.
    :param weight_box: box of the weights used by this stripe
    :param scale_tensor: optional standalone scale tensor. When falsy, the scales are
        taken from the weight stream itself.
    :param arch: architecture features (number of cores, memory regions)
    :return: tuple ``(weights, biases)`` of address range lists
    """
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else None

    w_tensor_src = weight_tensor.src_tensor if weight_tensor.src_tensor else weight_tensor
    depth_start = weight_box.start_coord[-1]

    core_offset = 0
    for core in range(arch.ncores):
        # Get weight range per core
        key = WeightKey(core, depth_start)
        if key not in w_tensor_src.encoded_ranges:
            continue
        weight_range = w_tensor_src.encoded_ranges[key]
        if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
            assert weight_tensor != w_tensor_src
            # Double buffered inside weight_tensor: pick the buffer half from the range index
            address = weight_tensor.address + w_tensor_src.max_range_bytes * ((weight_range.index - core) % 2)
            address += core_offset
            core_offset += round_up(weight_range.total_bytes, 16)
        elif weight_tensor == w_tensor_src:
            # Straight from source tensor
            address = weight_tensor.address + weight_range.offset
        else:
            # Single buffered inside weight tensor
            address = weight_tensor.address + core_offset
            core_offset += round_up(weight_range.total_bytes, 16)

        # Location of weights in tensor
        weights.append(
            NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
        )

        # Location of standalone scales or combined weights tensor scales
        if scale_tensor:
            assert scale_tensor.src_tensor is None  # Must be standalone
            scale_range = scale_tensor.encoded_ranges[key]
            scale_address = scale_tensor.address + scale_range.offset
            biases.append(
                NpuAddressRange(scale_region, int(scale_address), round_up(int(scale_range.scale_bytes), 16))
            )
        else:
            biases.append(NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16)))

    return weights, biases
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100322
323
def create_npu_activation(op: Operation) -> NpuActivation:
    """Builds the NpuActivation describing the op's fused activation function.

    No activation maps to NONE_OR_RELU with default limits. Tanh, Sigmoid and LUT map
    to their dedicated activation ops, and ReLU variants map to NONE_OR_RELU. In every
    non-None case, min/max and the LUT index are copied from ``op.activation``.

    :raises UnsupportedFeatureError: if the fused activation is of any other type
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    dedicated_ops = {
        Op.Tanh: NpuActivationOp.TANH,
        Op.Sigmoid: NpuActivationOp.SIGMOID,
        Op.LUT: NpuActivationOp.TABLE_LOOKUP,
    }
    act_op = dedicated_ops.get(faf)
    if act_op is None:
        if not faf.is_relu_op():
            raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")
        act_op = NpuActivationOp.NONE_OR_RELU

    act = NpuActivation(act_op)
    act.min, act.max = activation.min, activation.max
    act.lookup_table_index = activation.lut_index
    return act
344
345
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets the fields shared by all NPU block operations and returns ``npu_op``.

    Fills in IFM and OFM feature maps (shape and quantization), weights/biases when the
    stripe has a weight tensor, activation, fused-quantize flag, rounding mode and block
    config. Non-elementwise ops additionally get padding, kernel and IFM upscale.
    """
    ps = cmd.ps
    primary_op = ps.primary_op
    ifm_shape4d = ps.ifm_shapes[0]

    # IFM: height from the stripe box, width from the full IFM shape, depth per block type
    ifm_block = cmd.ifm_box.get_block()
    ifm_depth = get_ifm_depth(primary_op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ifm_shape4d)
    npu_op.ifm.shape = NpuShape3D(height=ifm_block.height, width=ifm_shape4d.width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: all three dimensions from the stripe box
    ofm_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(primary_op)
    npu_op.fused_quantize = any(fused.type == Op.Quantize for fused in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)
    # block_config entries 0, 1 and 3 are height, width and depth
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = get_upscale(primary_op)
    return npu_op
376
377
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation, including its block traversal.

    Vector products always use depth-first traversal. Other convolutions use the
    ``hw_traversal`` of the weight tensor's ``src_tensor`` when it has one, otherwise
    that of the weight tensor itself.
    """
    conv_op = NpuConv2DOperation()
    set_common_op_fields(conv_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        conv_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
        return conv_op
    traversal_source = cmd.weight_tensor.src_tensor or cmd.weight_tensor
    conv_op.block_traversal = traversal_source.hw_traversal
    return conv_op
390
391
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation (only common fields are needed)."""
    # set_common_op_fields fills the operation in place and hands the same object back
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
397
398
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation.

    Max pools map to MAX. Average pools and ResizeBilinear map to AVERAGE. ReduceSum maps
    to REDUCE_SUM. Any other op type fails an assertion (AVERAGE is used if asserts are
    disabled). The op's ``rescale`` is copied onto the result.
    """
    op = cmd.ps.primary_op
    op_type = op.type
    if op_type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
        pool_op = NpuPoolingOp.AVERAGE

    pooling = NpuPoolingOperation(pool_op)
    set_common_op_fields(pooling, cmd, arch)
    # Pooling specific info
    pooling.rescale = op.rescale
    return pooling
417
418
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation.

    For binary ops, the operands are swapped when needed so that a scalar or broadcasted
    operand ends up in IFM2 (see ifm_ifm2_correct_order). The swap is done in place on
    ``cmd`` and ``cmd.ps``, and ``reversed_operands`` records it. The OFM scale is
    overridden for some op types before returning.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        # An empty tensor shape denotes a scalar operand
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            # NOTE(review): this mutates cmd and the pass (ps.ifm_shapes) in place. If several
            # NpuStripe commands share the same pass, later stripes will see the already
            # swapped ps.ifm_shapes; confirm that this is intended.
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar: value goes into ifm2_scalar and the IFM2 shape is all zeros
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            # Height/depth from the stripe box, width from the full IFM2 shape
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    # Must run after the swap above: it reads cmd.ifm_tensor / ps.ifm_shapes[0]
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        # The LeakyRelu alpha is passed as the output scale
        output_scale = op.attrs["alpha"]
    if op.type == Op.RescaleAdd:
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            # NOTE(review): fixed scale used when Sigmoid/Tanh is fused; derivation not visible here
            output_scale = 1 / 0x3000
    if output_scale is not None:
        # Replace the OFM scale but keep its zero point
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
466
467
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation.

    For weight tensors, the transfer covers the encoded ranges of all cores for the depth
    slice starting at ``cmd.box.start_coord[-1]``, each rounded up to 16 bytes. Source
    and destination start at core 0's range; for a double-buffered destination the buffer
    half is selected from that range's index. For other tensors, the transfer covers
    ``cmd.box``.

    LUT destinations use the memory-to-memory region.
    """
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        src_addr = None
        dest_addr = None
        sz = 0
        depth_start = cmd.box.start_coord[-1]
        for core in range(arch.ncores):
            # Get weight range per core
            key = WeightKey(core, depth_start)
            if key not in cmd.in_tensor.encoded_ranges:
                continue
            weight_range = cmd.in_tensor.encoded_ranges[key]
            sz += round_up(weight_range.total_bytes, 16)

            if core == 0:
                src_addr = cmd.in_tensor.address + weight_range.offset
                dest_addr = cmd.out_tensor.address
                if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                    # Same buffer-half selection as create_weights, with core == 0
                    dest_addr += cmd.in_tensor.max_range_bytes * (weight_range.index % 2)
        # Previously this surfaced as an UnboundLocalError further down
        assert src_addr is not None, "No encoded weight range found for core 0"
    else:
        start_coord = cmd.box.start_coord
        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
503
504
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation.

    DMA commands become NpuDmaOperation. NpuStripe commands are dispatched on the primary
    op's NPU block type (convolution/vector product, depthwise, pooling/reduce-sum,
    elementwise). Any other command or block type fails an assertion instead of falling
    through to an unbound ``npu_op``.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    else:
        assert 0, f"Unsupported high level command {type(cmd).__name__}"
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100523
524
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates the register command stream for subgraph ``sg``.

    Each high level command is converted to an NpuOperation. Stripes whose pass has the
    Default block type are skipped with a warning. When the high level command stream is
    non-empty, the stream is registered in the debug database, and the result of
    generate_command_stream is stored in ``sg.register_command_stream``.
    """
    npu_ops = []
    cmd_for_op = {}  # NpuOperation -> high level command it was created from
    for hl_cmd in sg.high_level_command_stream:
        if isinstance(hl_cmd, NpuStripe) and hl_cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", hl_cmd.ps)
            continue
        converted = convert_command_to_npu_op(hl_cmd, arch)
        npu_ops.append(converted)
        cmd_for_op[converted] = hl_cmd
    mem_limits = get_mem_limits_for_regions(arch)

    # Nothing to generate for an empty high level command stream
    if len(sg.high_level_command_stream) == 0:
        return

    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Records the primary op behind each non-DMA operation in the debug database."""
        if isinstance(npu_op, NpuDmaOperation):
            return
        DebugDatabase.add_command(stream_id, offset, cmd_for_op[npu_op].ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_ops, arch, verbose, mem_limits, add_to_debug_db, cmd_for_op
    )