# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import WeightKey


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


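# Maps a vela DataType to the corresponding NpuDataType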
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
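    """Returns True if IFM and IFM2 are in the correct order; a scalar or broadcast feature map must be in IFM2"""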
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
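    """Creates the NpuPadding for the stripe, based on the operation's explicit padding and the stripe's position"""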
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # The ifm box coordinate range depends upon whether the primary op was combined with a split slice read
    ifm_read_offset = primary_op.read_offsets[0]
    ifm_read_shape = primary_op.read_shapes[0]
    if ifm_read_offset is None or len(ifm_read_offset) < 2:
        box_start_coord_min = 0
        box_end_coord_max = cmd.ps.ifm_shapes[0].width
    else:
        box_start_coord_min = ifm_read_offset[-2]
        box_end_coord_max = ifm_read_shape[-2]

    # Indexing from the end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because an activation function needed to be fused
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > box_start_coord_min:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < box_end_coord_max:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
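    """Returns the region (base pointer index) to use for the given memory type"""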
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[mem_type].value


def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    mem_limits = dict()
    for mem_type in MemType.all():
        mem_limits[get_region(mem_type, arch)] = arch.mem_type_size(mem_type)
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits


def get_upscale(op: Operation) -> NpuResamplingMode:
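    """Returns the IFM resampling (upscale) mode implied by the operation type"""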
    upscale = NpuResamplingMode.NONE
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        upscale = NpuResamplingMode.NEAREST
    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        upscale = NpuResamplingMode.TRANSPOSE
    return upscale


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
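    """Returns the IFM depth; from the IFM box for ConvolutionMxN/VectorProduct/ReduceSum blocks, otherwise from the OFM box"""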
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
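    # Zero point 0 is only used when there is no fused activation (or the output quantization is forced, or the
    # op is an average pool with a fused ReLU and no rescale), the op is not a ConcatSliceWrite and no Quantize
    # op is fused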
    use_0 = (
        (
            ps.primary_op.activation is None
            or forced_ofm_quantization is not None
            or (
                ps.primary_op.type.is_avgpool_op()
                and ps.primary_op.activation.op_type.is_relu_op()
                and not ps.primary_op.rescale
            )
        )
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
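    # Replace any None addresses with 0, since NpuTileBox requires integer addresses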
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm


def create_weights(
    weight_tensor: Tensor, weight_box: Box, scale_tensor: Tensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales"""
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = scale_tensor and get_region(scale_tensor.mem_type, arch)

    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = weight_tensor.src_tensor

    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                assert weight_tensor != w_tensor_src
                # Double buffered inside weight_tensor
                address = weight_tensor.address + w_tensor_src.max_range_bytes * ((weight_range.index - core) % 2)
                address += core_offset
                core_offset += round_up(weight_range.total_bytes, 16)
            else:
                if weight_tensor == w_tensor_src:
                    # Straight from source tensor
                    address = weight_tensor.address + weight_range.offset
                else:
                    # Single buffered inside weight tensor
                    address = weight_tensor.address + core_offset
                    core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
            zero_point = quant.zero_point
            if act.min is not None:
                act.min = scale_f32 * quantise_float32(act.min, scale_f32, zero_point)
            if act.max is not None:
                act.max = scale_f32 * quantise_float32(act.max, scale_f32, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
        npu_op.ifm_upscale = get_upscale(op)
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        if cmd.weight_tensor.src_tensor:
            npu_op.block_traversal = cmd.weight_tensor.src_tensor.hw_traversal
        else:
            npu_op.block_traversal = cmd.weight_tensor.hw_traversal
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key in cmd.in_tensor.encoded_ranges:
                weight_range = cmd.in_tensor.encoded_ranges[key]
                sz += round_up(weight_range.total_bytes, 16)

                if core == 0:
                    weight_range = cmd.in_tensor.encoded_ranges[key]
                    src_addr = cmd.in_tensor.address + weight_range.offset

                    if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                        dest_addr = cmd.out_tensor.address + cmd.in_tensor.max_range_bytes * (
                            (weight_range.index - core) % 2
                        )
                    else:
                        dest_addr = cmd.out_tensor.address
    else:
        start_coord = cmd.box.start_coord
        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(
            npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
        )