blob: 67d1cd9b08b2b592294b6a1ec3b63fad1a8a5a69 [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
23
24from .api import NpuActivation
25from .api import NpuActivationOp
26from .api import NpuAddressRange
27from .api import NpuBlockOperation
28from .api import NpuBlockTraversal
29from .api import NpuConv2DOperation
30from .api import NpuConvDepthWiseOperation
31from .api import NpuDataType
32from .api import NpuDmaOperation
33from .api import NpuElementWiseOp
34from .api import NpuElementWiseOperation
35from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010036from .api import NpuLayout
37from .api import NpuOperation
38from .api import NpuPadding
39from .api import NpuPoolingOp
40from .api import NpuPoolingOperation
41from .api import NpuQuantization
42from .api import NpuResamplingMode
43from .api import NpuRoundingMode
44from .api import NpuShape3D
45from .api import NpuTileBox
46from .architecture_features import ArchitectureFeatures
47from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010048from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000049from .errors import UnsupportedFeatureError
Louis Verhaarde8a5a782020-11-02 18:04:27 +010050from .high_level_command_stream import Box
51from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010052from .high_level_command_stream import DMA
53from .high_level_command_stream import NpuStripe
Tim Halld8339a72021-05-27 18:49:40 +010054from .numeric_util import round_up
Louis Verhaarde8a5a782020-11-02 18:04:27 +010055from .operation import NpuBlockType
56from .operation import Op
57from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010058from .register_command_stream_generator import generate_command_stream
59from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010060from .register_command_stream_util import to_npu_kernel
61from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000062from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010063from .tensor import MemType
64from .tensor import Tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010065from .tensor import TensorFormat
66from .tensor import TensorPurpose
Tim Halld8339a72021-05-27 18:49:40 +010067from .tensor import TensorSubPurpose
68from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010069
70
class BasePointerIndex(IntEnum):
    """Base pointer (region) indices that tensors are addressed through."""

    # Region holding the Weight tensor
    WeightTensor = 0
    # Region holding the Scratch_tensor in the TensorArena
    ScratchTensor = 1
    # Region holding the Scratch_fast_tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010075
76
# Maps Vela's internal DataType to the NpuDataType of the NPU API.
# Only these integer types can be expressed in a feature map.
dtype_map: Dict[DataType, NpuDataType] = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
84
85
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# The Rescale* variants share the hardware mode of their plain counterparts.
elementwise_op_map: Dict[Op, NpuElementWiseOp] = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
101
102
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether IFM/IFM2 are already in the order the NPU requires.

    A scalar operand (empty shape) or a broadcast operand (a dimension equal to 1 where
    the other operand's dimension differs) must be IFM2. Returns False when the operands
    need to be swapped.
    """
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True
    # Broadcasted FM needs to be in IFM2
    return not any(dim == 1 and dim != other for dim, other in zip(ifm_shape, ifm2_shape))
115
116
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used.

    NATURAL rounding is selected for ResizeBilinear, for int16 MxN/depthwise
    convolutions, and for average pools without a fused Quantize that write into a
    concat slice and whose kernel has elements_wh() == 1. Everything else uses TFL
    rounding. An explicit ``op.rounding_mode`` overrides either choice.
    """
    wants_natural = (
        op.type == Op.ResizeBilinear
        or (
            op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
            and op.ifm.dtype == DataType.int16
        )
        or (
            not fused_quantize
            and op.type.is_avgpool_op()
            and op.memory_function == Op.ConcatSliceWrite
            and op.kernel.elements_wh() == 1
        )
    )
    derived_mode = NpuRoundingMode.NATURAL if wants_natural else NpuRoundingMode.TFL
    return derived_mode if op.rounding_mode is None else op.rounding_mode
137
138
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Computes the padding to program for this stripe.

    Vector products get no padding. Otherwise the op's ``explicit_padding`` is used,
    with top/bottom taken from the stripe when the op is split into several horizontal
    stripes. Unless ``force_padding`` is set, left padding is dropped when the IFM box
    does not start at the first column, and right padding is dropped when the box ends
    before the full IFM width.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Horizontal IFM streaming: the stripe carries its own top/bottom padding
    is_single_h_stripe = cmd.is_first_h_stripe and cmd.is_last_h_stripe
    if not is_single_h_stripe:
        top, bottom = cmd.pad_top, cmd.pad_bottom

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because of activation function needed to be fused.
    if not primary_op.attrs.get("force_padding"):
        box = cmd.ifm_box
        if len(box.start_coord) >= 2 and box.start_coord[-2] > 0:
            left = 0
        if len(box.end_coord) >= 2 and box.end_coord[-2] < cmd.ps.ifm_shapes[0].width:
            right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100157
158
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer index (region) used for tensors of ``mem_type``.

    Permanent memory (NPU or CPU) lives in the weight region and Scratch in the scratch
    region. Scratch_fast only gets a region of its own when spilling is enabled;
    otherwise it shares the scratch region. An unknown mem_type raises KeyError.
    """
    if arch.is_spilling_enabled():
        scratch_fast_index = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_index = BasePointerIndex.ScratchTensor

    index_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return index_by_mem_type[mem_type].value
172
173
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes.

    When several memory types share a region, the entry for the last of them in
    ``MemType.all()`` order is kept. The MEM2MEM region is limited to the SHRAM size.
    """
    limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100181
182
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode required by ``op`` (NONE for most ops)."""
    if op.type == Op.ResizeBilinear:
        # Nearest neighbour upscale
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # Insert-zero upscale
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
192
193
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to program for a block.

    MxN convolutions, vector products and reduce-sum take the depth of the IFM box;
    every other block type uses the depth of the OFM box.
    """
    depth_from_ifm = npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100200
201
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point.

    This is always the case for int32 IFM tensors. Otherwise only AvgPool (without
    explicit scaling), ResizeBilinear, CLZ and SHL qualify. They qualify only when the
    primary op has no activation (or has a forced output quantization), does not write
    into a concat slice, and the pass contains no fused Quantize op.
    """
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    primary = ps.primary_op
    if primary.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if primary.type == Op.AvgPool and primary.explicit_scaling:
        # Explicitly scaled average pools keep the real zero point
        return False

    has_fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    activation_allows_0 = primary.activation is None or primary.forced_output_quantization is not None
    return activation_allows_0 and primary.memory_function != Op.ConcatSliceWrite and not has_fused_quantize
218
219
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2.

    The primary op's forced input quantization takes precedence over the tensor's own
    quantization when it is set. Returns None when neither is available.
    """
    forced = ps.primary_op.forced_input_quantization
    quant = forced if forced is not None else tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100231
232
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM.

    The primary op's forced output quantization (used for LUTs) takes precedence over
    the output tensor's quantization when it is set. Returns None when neither is
    available.
    """
    forced = ps.primary_op.forced_output_quantization
    quant = forced if forced is not None else tens.quantization
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
246
247
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated.

    Sets region, data type, layout, tile box and strides for ``box`` within ``tens``.
    Missing tile addresses are programmed as 0. Shape and quantization are not set here.
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    tile_h0, tile_h1, tile_w0, tile_addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    fm.tiles = NpuTileBox(
        height_0=tile_h0,
        height_1=tile_h1,
        width_0=tile_w0,
        addresses=[0 if address is None else int(address) for address in tile_addresses],
    )

    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm
271
272
def create_weights(
    weight_tensor: Tensor, weight_box: Box, scale_tensor: Optional[Tensor], arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales.

    For every core with an encoded weight range at the weight box's start depth, one
    weight range and one scale (bias) range are produced. All sizes are rounded up to
    16 bytes.

    :param weight_tensor: tensor the weights are read from. This is either the source
        tensor itself, or a single/double buffer whose encoded ranges are recorded on
        ``weight_tensor.src_tensor``.
    :param weight_box: box whose last start coordinate selects the encoded range
    :param scale_tensor: optional standalone scale tensor. When not given, the scales are
        taken from the weight tensor's encoded range.
    :param arch: architecture, used to resolve memory regions
    :return: tuple ``(weights, biases)`` of per-core address ranges
    """
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else None

    # Encoded ranges are recorded on the source tensor when weight_tensor is a buffer of it
    w_tensor_src = weight_tensor.src_tensor or weight_tensor

    core_offset = 0
    for core in range(arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key not in w_tensor_src.encoded_ranges:
            continue
        weight_range = w_tensor_src.encoded_ranges[key]

        if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
            assert weight_tensor != w_tensor_src
            # Double buffered inside weight_tensor: the half is chosen by (index - core) parity
            address = weight_tensor.address + w_tensor_src.max_range_bytes * ((weight_range.index - core) % 2)
            address += core_offset
            core_offset += round_up(weight_range.total_bytes, 16)
        elif weight_tensor == w_tensor_src:
            # Straight from source tensor
            address = weight_tensor.address + weight_range.offset
        else:
            # Single buffered inside weight tensor
            address = weight_tensor.address + core_offset
            core_offset += round_up(weight_range.total_bytes, 16)

        # Location of weights in tensor
        weights.append(
            NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
        )

        # Location of standalone scales or combined weights tensor scales
        if scale_tensor:
            assert scale_tensor.src_tensor is None  # Must be standalone
            scale_range = scale_tensor.encoded_ranges[key]
            scale_address = scale_tensor.address + scale_range.offset
            biases.append(
                NpuAddressRange(scale_region, int(scale_address), round_up(int(scale_range.scale_bytes), 16))
            )
        else:
            biases.append(NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16)))

    return weights, biases
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100325
326
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function.

    Without an activation a plain NONE_OR_RELU activation is returned. Otherwise Tanh,
    Sigmoid and LUT map to their NPU counterparts and any ReLU variant maps to
    NONE_OR_RELU. The activation's min/max and LUT index are copied over.

    Raises UnsupportedFeatureError for any other activation type.
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    if faf == Op.Tanh:
        npu_act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        npu_act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        npu_act_op = NpuActivationOp.TABLE_LOOKUP
    elif faf.is_relu_op():
        npu_act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    npu_act = NpuActivation(npu_act_op)
    npu_act.min = activation.min
    npu_act.max = activation.max
    npu_act.lookup_table_index = activation.lut_index
    return npu_act
347
348
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation and returns it.

    Populates:
    - IFM and OFM (feature map, shape and quantization);
    - weights/biases, when the stripe has a weight tensor;
    - activation, fused-quantize flag, rounding mode and block config;
    - IFM upscale mode.

    Padding and kernel are only set for non-elementwise ops.
    """
    ps = cmd.ps
    primary_op = ps.primary_op
    ifm_box = cmd.ifm_box
    ofm_box = cmd.ofm_box

    # IFM: height from this stripe's box, width from the full IFM shape of the pass
    ifm_shape = NpuShape3D(
        height=ifm_box.get_block().height,
        width=ps.ifm_shapes[0].width,
        depth=get_ifm_depth(primary_op.type.npu_block_type, ifm_box, ofm_box),
    )
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = ifm_shape
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: full extent of this stripe's OFM box
    ofm_block = ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(primary_op)
    npu_op.fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = get_upscale(primary_op)
    return npu_op
379
380
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation.

    Vector products always use depth-first block traversal. Other convolutions use the
    traversal recorded on the weight tensor's source tensor when it has one, otherwise
    the traversal recorded on the weight tensor itself.
    """
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        traversal_owner = cmd.weight_tensor.src_tensor or cmd.weight_tensor
        traversal = traversal_owner.hw_traversal
    npu_op.block_traversal = traversal
    return npu_op
393
394
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation; only the common fields are needed."""
    depthwise_op = NpuConvDepthWiseOperation()
    set_common_op_fields(depthwise_op, cmd, arch)
    return depthwise_op
400
401
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation.

    Max pools map to MAX, average pools and ResizeBilinear to AVERAGE, and ReduceSum to
    REDUCE_SUM. The op's rescale is passed through. When the op has explicit scaling,
    it is carried in the rescale field instead.
    """
    op = cmd.ps.primary_op
    op_type = op.type
    if op_type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
        # Only reached when asserts are disabled
        pool_op = NpuPoolingOp.AVERAGE

    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op
424
425
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation.

    For binary ops the IFM2 feature map is created as well. If the IFM is the scalar or
    broadcast operand, the two inputs are swapped first and reversed_operands is set.
    The swap is written back into cmd and ps.ifm_shapes. After the common fields are
    set, the OFM scale may be overridden depending on the op type.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        # An empty shape list marks a scalar operand for ifm_ifm2_correct_order
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            # NOTE: this mutates cmd and ps.ifm_shapes in place, not just the returned npu_op
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar: the value goes in ifm2_scalar and the IFM2 shape is all zeros
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            # Height/depth from this stripe's IFM2 box, width from the full IFM2 shape
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    # Must run after any operand swap above: it reads cmd.ifm_tensor, cmd.ifm_box and ps.ifm_shapes[0]
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    # (the checks below run in sequence; a later matching check overwrites output_scale)
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        # Abs: use the IFM/OFM scale ratio as output scale
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        # LeakyRelu: the alpha attribute is programmed as output scale
        output_scale = op.attrs["alpha"]
    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        # Rescale variants pass their explicit rescale through unchanged
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        # Fixed output scale 1/0x3000 when a Sigmoid/Tanh activation is fused
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        # Only the scale is replaced; the OFM zero point is kept
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
473
474
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation.

    Weight transfers copy, in a single DMA, the encoded ranges of all cores for the
    box's start depth, each rounded up to 16 bytes. The source/destination addresses
    are derived from core 0's range. Other transfers copy the bytes between the box's
    start and end coordinates. LUT destinations always use the MEM2MEM region.
    """
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core; the transfer size is the sum of all of them
        sz = 0
        src_addr = None
        dest_addr = None
        for core in range(arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key not in cmd.in_tensor.encoded_ranges:
                continue
            weight_range = cmd.in_tensor.encoded_ranges[key]
            sz += round_up(weight_range.total_bytes, 16)

            if core == 0:
                src_addr = cmd.in_tensor.address + weight_range.offset
                if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                    # Pick the double-buffer half by range index parity (core is 0 here)
                    dest_addr = cmd.out_tensor.address + cmd.in_tensor.max_range_bytes * (weight_range.index % 2)
                else:
                    dest_addr = cmd.out_tensor.address
        # Without a core 0 range the start addresses would be undefined
        assert src_addr is not None and dest_addr is not None, "No encoded weight range found for core 0"
    else:
        start_coord = cmd.box.start_coord
        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
510
511
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation.

    DMA commands become NpuDmaOperation. NpuStripe commands are dispatched on the
    primary op's NPU block type. Any other command type triggers an assertion instead
    of falling through to an unbound ``npu_op``.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    else:
        assert 0, f"Unknown command type {type(cmd).__name__}"
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100530
531
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream.

    Stripes whose pass has the Default block type are skipped with a warning. For an
    empty high level command stream nothing is generated, and neither
    sg.register_command_stream nor sg.generated_stream_id is assigned.
    """
    # Lower each high level command to an NpuOperation, remembering its origin
    npu_op_list = []
    npu_op_to_cmd = {}
    for hl_cmd in sg.high_level_command_stream:
        if isinstance(hl_cmd, NpuStripe) and hl_cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", hl_cmd.ps)
            continue
        lowered = convert_command_to_npu_op(hl_cmd, arch)
        npu_op_list.append(lowered)
        npu_op_to_cmd[lowered] = hl_cmd
    mem_limits = get_mem_limits_for_regions(arch)

    # Generate register commands
    if len(sg.high_level_command_stream) == 0:
        return
    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Records the primary op of a non-DMA operation at ``offset`` in the debug database"""
        if isinstance(npu_op, NpuDmaOperation):
            return
        origin = npu_op_to_cmd[npu_op]
        DebugDatabase.add_command(stream_id, offset, origin.ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
    )