# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import WeightKey


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


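# Maps a DataType to the corresponding NpuDataType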
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
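    """Returns True if ifm_shape/ifm2_shape are already in the order required by the NPU (scalar/broadcast in IFM2)"""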
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
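    """Creates the explicit padding to be used for the stripe"""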
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # the ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    ifm_read_offset = primary_op.read_offsets[0]
    ifm_read_shape = primary_op.read_shapes[0]
    if ifm_read_offset is None or len(ifm_read_offset) < 2:
        box_start_coord_min = 0
        box_end_coord_max = cmd.ps.ifm_shapes[0].width
    else:
        box_start_coord_min = ifm_read_offset[-2]
        box_end_coord_max = ifm_read_shape[-2]

    # Indexing from the end, since a 1x1 AvgPool might have been added with non-4-dimensional input/output
    # because an activation function needed to be fused.
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > box_start_coord_min:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < box_end_coord_max:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
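    """Returns the base pointer index (region) to use for the given memory type"""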
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

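    # With spilling enabled, Scratch_fast gets its own base pointer; otherwise it shares the scratch region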
    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[mem_type].value


def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    mem_limits = dict()
    for mem_type in MemType.all():
        mem_limits[get_region(mem_type, arch)] = arch.mem_type_size(mem_type)
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits


def get_upscale(op: Operation) -> NpuResamplingMode:
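    """Returns the IFM resampling (upscale) mode to be used for the operation"""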
    upscale = NpuResamplingMode.NONE
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        upscale = NpuResamplingMode.NEAREST
    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        upscale = NpuResamplingMode.TRANSPOSE
    return upscale


def get_double_buffer_offset(arch: ArchitectureFeatures, range_index: int, core: int) -> int:
    """Returns 0 if the first half of a double buffer should be used, 1 if the second half should be used"""
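    # Used both when computing the weight read address and the DMA destination address,
    # so the producer and consumer of a weight range agree on which half to use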
    return ((range_index - core) // arch.ncores) % 2


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (
            ps.primary_op.activation is None
            or forced_ofm_quantization is not None
            or (
                ps.primary_op.type.is_avgpool_op()
                and ps.primary_op.activation.op_type.is_relu_op()
                and not ps.primary_op.rescale
            )
        )
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
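    # Unused tile addresses come back as None; NpuTileBox expects integer addresses, so replace them with 0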
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm


def create_weights(
    weight_tensor: Tensor, weight_box: Box, scale_tensor: Tensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales"""
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = scale_tensor and get_region(scale_tensor.mem_type, arch)

    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = weight_tensor.src_tensor

    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                assert weight_tensor != w_tensor_src
                # Double buffered inside weight_tensor
                address = weight_tensor.address + core_offset
                address += get_double_buffer_offset(arch, weight_range.index, core) * w_tensor_src.max_range_bytes
                core_offset += round_up(weight_range.total_bytes, 16)
            else:
                if weight_tensor == w_tensor_src:
                    # Straight from source tensor
                    address = weight_tensor.address + weight_range.offset
                else:
                    # Single buffered inside weight tensor
                    address = weight_tensor.address + core_offset
                    core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
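    # For average pools without rescale, adjust the clamp values when the output zero point is not 0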
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
            zero_point = quant.zero_point
            if act.min is not None:
                act.min = scale_f32 * quantise_float32(act.min, scale_f32, zero_point)
            if act.max is not None:
                act.max = scale_f32 * quantise_float32(act.max, scale_f32, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        if cmd.weight_tensor.src_tensor:
            npu_op.block_traversal = cmd.weight_tensor.src_tensor.hw_traversal
        else:
            npu_op.block_traversal = cmd.weight_tensor.hw_traversal
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: rescale is reused for explicit scaling so that it does not have to be exposed in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
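        # Weight DMA: the size is the sum of the per-core encoded ranges and the
        # source/destination addresses are taken from the core 0 range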
        # Get weight range per core
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key in cmd.in_tensor.encoded_ranges:
                weight_range = cmd.in_tensor.encoded_ranges[key]
                sz += round_up(weight_range.total_bytes, 16)

                if core == 0:
                    weight_range = cmd.in_tensor.encoded_ranges[key]
                    src_addr = cmd.in_tensor.address + weight_range.offset

                    if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                        dest_addr = cmd.out_tensor.address + cmd.in_tensor.max_range_bytes * (
                            get_double_buffer_offset(arch, weight_range.index, core)
                        )
                    else:
                        dest_addr = cmd.out_tensor.address
    else:
        start_coord = cmd.box.start_coord
        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(
            npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
        )