blob: 4ef7bee870e79487c47749fb286fe2365e82573c [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069
70
class BasePointerIndex(IntEnum):
    """Base pointer indices (regions) through which the NPU addresses tensors."""

    WeightTensor = 0  # region holding the weight tensor
    ScratchTensor = 1  # region holding the scratch tensor in the tensor arena
    ScratchFastTensor = 2  # region holding the scratch-fast tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010075
76
# Translation of the compiler's tensor DataType into the NPU API data type
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# RescaleAdd shares the ADD mode; its extra rescale is carried separately on the NPU operation.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
100
101
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the two elementwise operands are already in hardware order.

    The NPU requires a scalar operand (empty shape) or a broadcast operand to be
    IFM2. Returns False when IFM is such an operand and the inputs must be swapped.
    """
    if ifm_shape == []:
        # A scalar has to be IFM2, so the operands must be swapped
        return False
    if ifm2_shape == []:
        return True
    # IFM is broadcast if it has size 1 along an axis where IFM2 is larger
    return not any(dim == 1 and dim != dim2 for dim, dim2 in zip(ifm_shape, ifm2_shape))
114
115
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Chooses the rounding mode used when scaling the operation's output.

    ResizeBilinear truncates. int16 convolutions (MxN or depthwise) and unfused
    single-element-kernel average pools that write into a ConcatSliceWrite use
    NATURAL rounding. Everything else uses TFL rounding. A rounding mode set
    explicitly on the op overrides all of these.
    """
    conv_block_types = (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
    if op.type == Op.ResizeBilinear:
        mode = NpuRoundingMode.TRUNCATE
    elif (op.type.npu_block_type in conv_block_types and op.ifm.dtype == DataType.int16) or (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        mode = NpuRoundingMode.NATURAL
    else:
        mode = NpuRoundingMode.TFL
    # An explicitly requested rounding mode always wins
    return mode if op.rounding_mode is None else op.rounding_mode
136
137
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Computes the padding the NPU should apply for this stripe.

    VectorProduct block types get no padding. Otherwise the op's explicit padding
    is used, with top/bottom taken from the stripe when the IFM is streamed in more
    than one horizontal stripe. Unless the op sets "force_padding", left/right
    padding is dropped when the IFM box does not touch the left/right IFM edge.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)

    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    streamed_horizontally = not (cmd.is_first_h_stripe and cmd.is_last_h_stripe)
    if streamed_horizontally:
        top, bottom = cmd.pad_top, cmd.pad_bottom

    if not primary_op.attrs.get("force_padding"):
        # Coordinates are indexed from the end: a 1x1 AvgPool added to fuse an activation
        # function may have non 4-dimensional input/output
        box_start = cmd.ifm_box.start_coord
        box_end = cmd.ifm_box.end_coord
        if len(box_start) >= 2 and box_start[-2] > 0:
            left = 0
        if len(box_end) >= 2 and box_end[-2] < cmd.ps.ifm_shapes[0].width:
            right = 0

    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100156
157
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer index (region) used to address tensors of mem_type.

    Permanent memory (NPU or CPU) maps to the weight region and Scratch to the
    scratch region. Scratch_fast gets its own region only when spilling is
    enabled; otherwise it shares the scratch region. Any other memory type raises
    KeyError.
    """
    if arch.is_spilling_enabled():
        scratch_fast_index = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_index = BasePointerIndex.ScratchTensor

    index_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return index_by_mem_type[mem_type].value
171
172
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes

    When several memory types share a region, the size of the one listed last by
    MemType.all() is kept. The MEM2MEM region is limited by the SHRAM size.
    """
    limits = {get_region(mt, arch): arch.mem_type_size(mt) for mt in MemType.all()}
    limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100180
181
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns how the NPU should upscale the IFM for this operation."""
    if op.type == Op.ResizeBilinear:
        # nearest neighbour upscale
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # insert-zero upscale
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
191
192
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth of the stripe.

    ConvolutionMxN, VectorProduct and ReduceSum blocks take the depth from the IFM
    box; all other block types take it from the OFM box.
    """
    depth_from_ifm = npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100199
200
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Returns True if the quantization of tens should use 0 as zero point

    This always holds for an int32 IFM. Otherwise it only applies when the primary
    op is AvgPool, ResizeBilinear, CLZ or SHL and all of the following are true:
    the op has no activation (or has a forced output quantization), it does not
    write into a ConcatSliceWrite, and the pass has no fused Quantize op.
    """
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    primary = ps.primary_op
    if primary.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    has_fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    forced_ofm_quant = primary.forced_output_quantization
    if primary.activation is not None and forced_ofm_quant is None:
        return False
    if primary.memory_function == Op.ConcatSliceWrite:
        return False
    return not has_fused_quantize
215
216
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Returns the NPU quantization of an IFM/IFM2 tensor of the pass.

    The primary op's forced input quantization, when set, replaces the tensor's own
    quantization. Returns None when no quantization is available.
    """
    forced = ps.primary_op.forced_input_quantization
    quant = tens.quantization if forced is None else forced
    if quant is None:
        return None
    zp = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zp)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100228
229
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Returns the NPU quantization of the OFM tensor of the pass.

    The primary op's forced output quantization (used by LUTs) replaces the
    tensor's own quantization when set. Returns None when no quantization is
    available.
    """
    forced = ps.primary_op.forced_output_quantization
    quant = tens.quantization if forced is None else forced
    if quant is None:
        return None
    zp = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zp)
243
244
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Builds an NpuFeatureMap describing the part of tens covered by box.

    Fills in region, data type, layout, tile addresses and strides. Shape and
    quantization are left for the caller to set.
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]

    fmt = tens.format
    if fmt == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif fmt == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    tile_h0, tile_h1, tile_w0, tile_addrs = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Addresses reported as None are programmed as 0 (the list is patched in place)
    for i, tile_addr in enumerate(tile_addrs):
        if tile_addr is None:
            tile_addrs[i] = 0
    fm.tiles = NpuTileBox(
        height_0=tile_h0,
        height_1=tile_h1,
        width_0=tile_w0,
        addresses=[int(tile_addr) for tile_addr in tile_addrs],
    )

    # Stride indices used here: 1 -> depth, 2 -> height, 3 -> width
    tens_strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(
        height=int(tens_strides[2]),
        width=int(tens_strides[3]),
        depth=int(tens_strides[1]),
    )
    return fm
268
269
def create_weights(
    weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales

    One weights range and one scales/biases range is produced for every core that
    has an encoded range for the slice starting at weight_box.start_coord[-1].

    :param weight_tensor: tensor the NPU reads weights from. When it has a
        src_tensor, the encoded ranges are looked up on that source tensor.
    :param weight_box: box whose last start coordinate selects the encoded ranges
    :param arch: architecture features (number of cores, memory region mapping)
    :return: tuple (weights, biases), two lists of NpuAddressRange
    """
    weights = []
    biases = []
    region = get_region(weight_tensor.mem_type, arch)

    # Encoded ranges are stored on the source tensor when weight_tensor is a buffer of it
    w_tensor_src = weight_tensor.src_tensor or weight_tensor

    core_offset = 0
    for core in range(arch.ncores):
        key = WeightKey(core, weight_box.start_coord[-1])
        if key not in w_tensor_src.encoded_ranges:
            continue
        weight_range = w_tensor_src.encoded_ranges[key]
        if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
            assert weight_tensor != w_tensor_src
            # Double buffered inside weight_tensor: the range index selects the buffer half
            address = weight_tensor.address + w_tensor_src.max_range_bytes * ((weight_range.index - core) % 2)
            address += core_offset
            core_offset += round_up(weight_range.total_bytes, 16)
        elif weight_tensor == w_tensor_src:
            # Straight from source tensor
            address = weight_tensor.address + weight_range.offset
        else:
            # Single buffered inside weight tensor: per-core ranges placed back to back
            address = weight_tensor.address + core_offset
            core_offset += round_up(weight_range.total_bytes, 16)

        # Weights are located at weight_offset within the range
        addr_range = NpuAddressRange(
            region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
        )
        weights.append(addr_range)
        # Scales/biases are located at the start of the range
        addr_range = NpuAddressRange(region, int(address), round_up(int(weight_range.scale_bytes), 16))
        biases.append(addr_range)

    return weights, biases
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100311
312
def create_npu_activation(op: Operation) -> NpuActivation:
    """Translates the op's fused activation function into an NpuActivation.

    An op without an activation gets NpuActivation(NONE_OR_RELU). Tanh, Sigmoid
    and LUT map to their NPU counterparts, and ReLU variants map to NONE_OR_RELU.
    In all of these cases min, max and lookup_table_index are copied from the op's
    activation. Any other activation raises UnsupportedFeatureError.
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif faf.is_relu_op():
        # ReLU variants share NONE_OR_RELU; their bounds come from min/max below
        act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    npu_activation = NpuActivation(act_op)
    npu_activation.min, npu_activation.max = activation.min, activation.max
    npu_activation.lookup_table_index = activation.lut_index
    return npu_activation
333
334
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Populates the fields shared by all NPU block operations and returns npu_op

    Sets the IFM and OFM feature maps (including shape and quantization), the
    weights/biases when the stripe has a weight tensor, the activation, the
    fused-quantize flag, the rounding mode, the block config and the IFM upscale
    mode. Padding and kernel are only set for non-elementwise operations.
    """
    ps = cmd.ps
    op = ps.primary_op
    ifm_shape4d = ps.ifm_shapes[0]

    # IFM: height from the stripe's IFM box, width of the whole IFM, depth per block type
    ifm_h = cmd.ifm_box.get_block().height
    ifm_d = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ifm_shape4d)
    npu_op.ifm.shape = NpuShape3D(height=ifm_h, width=ifm_shape4d.width, depth=ifm_d)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: shape of the stripe's OFM box
    ofm_blk = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_blk.height, width=ofm_blk.width, depth=ofm_blk.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(op)

    has_fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    npu_op.fused_quantize = has_fused_quantize
    npu_op.rounding_mode = get_rounding_mode(op, has_fused_quantize)

    # block_config entries 0, 1 and 3 give height, width and depth
    cfg = ps.block_config
    npu_op.block_config = NpuShape3D(height=cfg[0], width=cfg[1], depth=cfg[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op
365
366
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation

    VectorProduct ops always traverse depth first. Other ops use the hw_traversal
    of the weight tensor's source tensor if there is one, otherwise of the weight
    tensor itself.
    """
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type != NpuBlockType.VectorProduct:
        weights = cmd.weight_tensor
        npu_op.block_traversal = (weights.src_tensor or weights).hw_traversal
    else:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    return npu_op
379
380
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation (only the common fields are needed)"""
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
386
387
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation

    Max pools map to MAX; average pools and ResizeBilinear map to AVERAGE; ReduceSum
    maps to REDUCE_SUM. Any other op type fails an assertion. The op's rescale is
    copied onto the result.
    """
    op = cmd.ps.primary_op
    op_type = op.type
    # AVERAGE is also what remains if the assertion below is compiled out (python -O)
    pool_op = NpuPoolingOp.AVERAGE
    if op_type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"

    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    npu_op.rescale = op.rescale  # pooling specific
    return npu_op
406
407
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation

    For binary operations the operands are swapped when needed, so that a scalar or
    broadcast feature map ends up as IFM2. The swap happens in place on
    cmd.ifm_tensor/cmd.ifm2_tensor, cmd.ifm_box/cmd.ifm2_box and ps.ifm_shapes, and
    npu_op.reversed_operands is set. A scalar IFM2 is passed as npu_op.ifm2_scalar.

    The OFM scale is then overridden for: Add created from a resizebilinear
    (IFM2 scale), Abs (IFM scale / OFM scale), LeakyRelu (alpha), and Add/Mul/Sub
    with a fused Sigmoid/Tanh activation (1 / 0x3000). RescaleAdd copies op.rescale.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        # An empty shape denotes a scalar operand
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            # NOTE(review): ps belongs to the pass and may be shared by several NpuStripe
            # commands; swapping ps.ifm_shapes here is a mutation of shared state. Confirm
            # that a pass needing a swap is never converted through more than one stripe.
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    # Called after any swap, so the common IFM fields describe the final (swapped) IFM
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type == Op.RescaleAdd:
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        # Only the scale is replaced; the OFM zero point is kept
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
456
457
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation

    LUT destinations are addressed through the MEM2MEM base pointer; all other
    tensors use the region of their memory type. For weight tensors the transfer
    size is the sum of every core's encoded range (each rounded up to 16 bytes),
    and the addresses come from core 0's range. Other tensors are copied from
    box.start_coord up to box.end_coord.
    """
    src_tens = cmd.in_tensor
    dst_tens = cmd.out_tensor

    src_region = get_region(src_tens.mem_type, arch)
    if dst_tens.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(dst_tens.mem_type, arch)

    if src_tens.purpose == TensorPurpose.Weights:
        sz = 0
        for core in range(arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key not in src_tens.encoded_ranges:
                continue
            weight_range = src_tens.encoded_ranges[key]
            sz += round_up(weight_range.total_bytes, 16)
            if core != 0:
                continue
            # NOTE(review): src_addr/dest_addr are only bound when core 0 has a range
            src_addr = src_tens.address + weight_range.offset
            if dst_tens.sub_purpose == TensorSubPurpose.DoubleBuffer:
                # The range index selects which half of the double buffer is written
                dest_addr = dst_tens.address + src_tens.max_range_bytes * ((weight_range.index - core) % 2)
            else:
                dest_addr = dst_tens.address
    else:
        first = cmd.box.start_coord
        src_addr = src_tens.address_for_coordinate(first)
        dest_addr = dst_tens.address_for_coordinate(first)
        sz = src_tens.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr

    return NpuDmaOperation(
        NpuAddressRange(src_region, int(src_addr), int(sz)),
        NpuAddressRange(dest_region, int(dest_addr), int(sz)),
    )
493
494
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    DMA commands become NpuDmaOperation. NpuStripe commands are dispatched on the
    primary op's NPU block type. Any other command type, or any unhandled block
    type, fails an assertion with an explicit message, instead of reaching the
    return with npu_op unbound.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    else:
        # Previously fell through to the return with npu_op unbound (UnboundLocalError)
        assert 0, f"Unknown command type {type(cmd).__name__}"
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100513
514
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream

    Each high level command is converted to an NpuOperation, except stripes whose
    pass has the Default block type, which are skipped with a warning. When the
    subgraph has no high level commands, nothing is generated and
    sg.register_command_stream is left untouched. nng is not used here.
    """
    npu_ops = []
    op_to_cmd = {}  # NpuOperation -> high level command it was converted from
    for hl_cmd in sg.high_level_command_stream:
        if isinstance(hl_cmd, NpuStripe) and hl_cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", hl_cmd.ps)
            continue
        npu_op = convert_command_to_npu_op(hl_cmd, arch)
        npu_ops.append(npu_op)
        op_to_cmd[npu_op] = hl_cmd
    mem_limits = get_mem_limits_for_regions(arch)

    if len(sg.high_level_command_stream) == 0:
        return

    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Records the primary op of every non-DMA operation at its stream offset"""
        if isinstance(npu_op, NpuDmaOperation):
            return
        DebugDatabase.add_command(stream_id, offset, op_to_cmd[npu_op].ps.primary_op)

    # Generate register commands
    sg.register_command_stream = generate_command_stream(
        npu_ops, arch, verbose, mem_limits, add_to_debug_db, op_to_cmd
    )