# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import WeightKey


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
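# Illustrative lookups (not used anywhere in this module): elementwise_op_map[Op.Mul] and
# elementwise_op_map[Op.RescaleMul] both give NpuElementWiseOp.MUL, since the rescaled
# variants reuse the same NPU elementwise mode as their plain counterparts.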


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True
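# Illustrative examples (assumed shapes, not taken from a real network):
#   ifm_ifm2_correct_order([1, 4, 4, 8], []) -> True, the scalar is already in IFM2
#   ifm_ifm2_correct_order([1, 4, 4, 1], [1, 4, 4, 8]) -> False, the broadcast feature map must be moved to IFM2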


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode
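# For example, an int16 ConvolutionMxN or ConvolutionDepthWise op ends up with
# NpuRoundingMode.NATURAL, and an explicit op.rounding_mode always overrides the
# heuristics above.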


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # Indexing from the end, since a 1x1 AvgPool might have been added with non 4-dimensional input/output
    # because an activation function needed to be fused.
    if not primary_op.attrs.get("force_padding"):
        if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > 0:
            left = 0
        if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < cmd.ps.ifm_shapes[0].width:
            right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[mem_type].value
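# Illustrative behaviour: Permanent_NPU and Permanent_CPU map to the WeightTensor region (0),
# Scratch to ScratchTensor (1), and Scratch_fast to ScratchFastTensor (2) only when spilling
# is enabled, otherwise Scratch_fast shares the ScratchTensor region.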


def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    mem_limits = dict()
    for mem_type in MemType.all():
        mem_limits[get_region(mem_type, arch)] = arch.mem_type_size(mem_type)
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits
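# Sketch of the returned mapping (actual sizes are architecture dependent): keys are the
# region indices produced by get_region() plus BASE_PTR_INDEX_MEM2MEM, values are the
# corresponding maximum sizes in bytes; when several memory types share a region, the size
# recorded is simply the last one written in the loop above.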


def get_upscale(op: Operation) -> NpuResamplingMode:
    upscale = NpuResamplingMode.NONE
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        upscale = NpuResamplingMode.NEAREST
    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        upscale = NpuResamplingMode.TRANSPOSE
    return upscale


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (
            ps.primary_op.activation is None
            or forced_ofm_quantization is not None
            or (
                ps.primary_op.type.is_avgpool_op()
                and ps.primary_op.activation.op_type.is_relu_op()
                and not ps.primary_op.rescale
            )
        )
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100240
241
242def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
243 """Gets quantization for OFM"""
244 op = ps.primary_op
245 # Check if operation's output quantization is should be used instead of the output tensor's quantization
246 # (used in LUTs)
247 ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
248 if ofm_quant is None:
249 return None
250 if use_zero_point_0(ps, tens, False):
251 zero_point = 0
252 else:
253 zero_point = int(ofm_quant.zero_point)
254 return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm


def create_weights(
    weight_tensor: Tensor, weight_box: Box, scale_tensor: Tensor, arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales"""
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = scale_tensor and get_region(scale_tensor.mem_type, arch)

    w_tensor_src = weight_tensor
    if weight_tensor.src_tensor:
        w_tensor_src = weight_tensor.src_tensor

    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core
        key = WeightKey(core, weight_box.start_coord[-1])
        if key in w_tensor_src.encoded_ranges:
            weight_range = w_tensor_src.encoded_ranges[key]
            if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                assert weight_tensor != w_tensor_src
                # Double buffered inside weight_tensor
                address = weight_tensor.address + w_tensor_src.max_range_bytes * ((weight_range.index - core) % 2)
                address += core_offset
                core_offset += round_up(weight_range.total_bytes, 16)
            else:
                if weight_tensor == w_tensor_src:
                    # Straight from source tensor
                    address = weight_tensor.address + weight_range.offset
                else:
                    # Single buffered inside weight tensor
                    address = weight_tensor.address + core_offset
                    core_offset += round_up(weight_range.total_bytes, 16)

            # Location of weights in tensor
            addr_range = NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
            weights.append(addr_range)

            # Location of standalone scales or combined weights tensor scales
            if scale_tensor:
                assert scale_tensor.src_tensor is None  # Must be standalone
                scale_range = scale_tensor.encoded_ranges[key]
                address = scale_tensor.address + scale_range.offset
                addr_range = NpuAddressRange(scale_region, int(address), round_up(int(scale_range.scale_bytes), 16))
            else:
                addr_range = NpuAddressRange(shared_region, int(address), round_up(int(weight_range.scale_bytes), 16))

            biases.append(addr_range)

    return weights, biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
            zero_point = quant.zero_point
            if act.min is not None:
                act.min = scale_f32 * quantise_float32(act.min, scale_f32, zero_point)
            if act.max is not None:
                act.max = scale_f32 * quantise_float32(act.max, scale_f32, zero_point)
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        if cmd.weight_tensor.src_tensor:
            npu_op.block_traversal = cmd.weight_tensor.src_tensor.hw_traversal
        else:
            npu_op.block_traversal = cmd.weight_tensor.hw_traversal
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Note: reuse of rescale for explicit scaling to not expose this in the external API
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    if cmd.in_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key in cmd.in_tensor.encoded_ranges:
                weight_range = cmd.in_tensor.encoded_ranges[key]
                sz += round_up(weight_range.total_bytes, 16)

                if core == 0:
                    weight_range = cmd.in_tensor.encoded_ranges[key]
                    src_addr = cmd.in_tensor.address + weight_range.offset

                    if cmd.out_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                        dest_addr = cmd.out_tensor.address + cmd.in_tensor.max_range_bytes * (
                            (weight_range.index - core) % 2
                        )
                    else:
                        dest_addr = cmd.out_tensor.address
    else:
        start_coord = cmd.box.start_coord
        src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
        dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    return npu_op
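# Dispatch summary: DMA commands become NpuDmaOperation, while NpuStripe commands are
# converted according to their primary op's block type, e.g. both ConvolutionMxN and
# VectorProduct stripes go through create_npu_conv2d_op.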


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(
            npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
        )