# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
19from enum import IntEnum
20from typing import List
21from typing import Optional
22
23from .api import NpuActivation
24from .api import NpuActivationOp
25from .api import NpuAddressRange
26from .api import NpuBlockOperation
27from .api import NpuBlockTraversal
28from .api import NpuConv2DOperation
29from .api import NpuConvDepthWiseOperation
30from .api import NpuDataType
31from .api import NpuDmaOperation
32from .api import NpuElementWiseOp
33from .api import NpuElementWiseOperation
34from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010035from .api import NpuLayout
36from .api import NpuOperation
37from .api import NpuPadding
38from .api import NpuPoolingOp
39from .api import NpuPoolingOperation
40from .api import NpuQuantization
41from .api import NpuResamplingMode
42from .api import NpuRoundingMode
43from .api import NpuShape3D
44from .api import NpuTileBox
45from .architecture_features import ArchitectureFeatures
46from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010047from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000048from .errors import UnsupportedFeatureError
Louis Verhaarde8a5a782020-11-02 18:04:27 +010049from .high_level_command_stream import Box
50from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010051from .high_level_command_stream import DMA
52from .high_level_command_stream import NpuStripe
Louis Verhaarde8a5a782020-11-02 18:04:27 +010053from .operation import NpuBlockType
54from .operation import Op
55from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010056from .register_command_stream_generator import generate_command_stream
57from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010058from .register_command_stream_util import to_npu_kernel
59from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000060from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010061from .tensor import MemType
62from .tensor import Tensor
63from .tensor import TensorBlockTraversal
64from .tensor import TensorFormat
65from .tensor import TensorPurpose
66
67
class BasePointerIndex(IntEnum):
    """Indices of the base pointers used to address the different tensor areas."""

    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor
Louis Verhaarde8a5a782020-11-02 18:04:27 +010072
73
# Maps a tensor DataType to the NpuDataType that is stored in NpuFeatureMap.data_type.
# Looking up a data type that is not listed here raises KeyError.
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
81
82
# Maps the block traversal of a weight tensor to the NpuBlockTraversal value set on the NPU operation
block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}
87
88
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# Note that both Op.Add and Op.RescaleAdd map to NpuElementWiseOp.ADD.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
103
104
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Checks if the two inputs of a binary elementwise operation are in the correct order.

    A scalar (empty shape) or a broadcasted feature map (a dimension of size 1 where
    the other input has a different size) needs to be in IFM2.

    Returns True if the inputs can be used as they are, False if they must be swapped.
    """
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    # Broadcasted FM needs to be in IFM2
    return not any(ifm != ifm2 and ifm == 1 for ifm, ifm2 in zip(ifm_shape, ifm2_shape))
117
118
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used

    The default is TFL rounding. ResizeBilinear uses TRUNCATE; int16 convolutions
    (MxN and depthwise) and 1x1 average pools that write into a concat slice without
    a fused quantize use NATURAL. An explicit op.rounding_mode overrides all of these.
    """
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    # A rounding mode that is set on the operation always takes precedence
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode
139
140
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Creates the padding to be used for the given stripe

    Vector products use no padding. Otherwise the operation's "explicit_padding"
    attribute is the starting point, adjusted for the part of the IFM that the
    stripe covers.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
    # because of activation function needed to be fused.
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > 0:
        # The box does not start at the left edge of the IFM
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < cmd.ps.ifm_shapes[0].width:
        # The box does not reach the right edge of the IFM
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100158
159
def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer index (region) to be used for the given tensor

    Scratch_fast tensors only get a region of their own when spilling is enabled;
    otherwise they use the same region as Scratch tensors.

    Raises KeyError if the tensor's mem_type is not present in the mapping.
    """
    if arch.is_spilling_enabled():
        scratch_fast_idx = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_idx = BasePointerIndex.ScratchTensor

    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_idx,
    }
    return base_ptr_idx_map[tens.mem_type].value
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100173
174
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode for the operation; NONE unless the op type requires upscaling"""
    upscale = NpuResamplingMode.NONE
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        upscale = NpuResamplingMode.NEAREST
    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        upscale = NpuResamplingMode.TRANSPOSE
    return upscale
184
185
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to use for the given block type

    For ConvolutionMxN, VectorProduct and ReduceSum the depth is taken from the
    IFM box, for all other block types from the OFM box.
    """
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100192
193
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point

    True for int32 IFM tensors. For other tensors it can only be True if the
    primary op of the pass is AvgPool, ResizeBilinear, CLZ or SHL, and then only when:
    - the primary op has no activation, or has a forced output quantization, and
    - the primary op does not write into a concat slice, and
    - no Quantize op is part of the pass
    """
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (ps.primary_op.activation is None or forced_ofm_quantization is not None)
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0
208
209
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2

    The primary op's forced input quantization is used if it is set, otherwise the
    tensor's own quantization. Returns None if neither is available.
    """
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100221
222
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM

    Returns None if neither a forced output quantization nor a tensor quantization is available.
    """
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
236
237
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated

    Populates region, data type, layout, tiles and strides. Shape and quantization
    are not set here.
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Missing tile addresses (None) are passed on as address 0. A new list is built so
    # that the list returned by the tensor is not modified.
    tile_addresses = [0 if addr is None else int(addr) for addr in addresses]
    fm.tiles = NpuTileBox(height_0=height_0, height_1=height_1, width_0=width_0, addresses=tile_addresses)
    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm
261
262
def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights

    One address range is returned per substream of the compressed weight stream
    that is selected by the start coordinate of the weight box.
    """
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    weight_substream_offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    # Offset list must start at 0 and terminate with full stream length
    assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0)
    substreams = len(weight_substream_offsets) - 1

    # Extract weight substream offsets and calculate their lengths
    weight_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
    weights = []
    for core in range(substreams):
        address = weight_addr + weight_substream_offsets[core]
        length = weight_substream_offsets[core + 1] - weight_substream_offsets[core]
        weights.append(NpuAddressRange(region, int(address), int(length)))
    return weights
280
281
def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases

    The stream index is looked up via the weight tensor, the substream offsets and
    addresses are taken from the scale tensor. One address range is returned per substream.
    """
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    scale_substream_offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    # Offset list must start at 0 and terminate with full stream length
    assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0)
    substreams = len(scale_substream_offsets) - 1

    # Extract scale substream offsets and calculate their lengths.
    # Only the last element of the weight box start coordinate is used to address the scale tensor.
    scale_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])

    region = get_region(scale_tensor, arch)
    biases = []
    for core in range(substreams):
        address = scale_addr + scale_substream_offsets[core]
        length = scale_substream_offsets[core + 1] - scale_substream_offsets[core]
        biases.append(NpuAddressRange(region, int(address), int(length)))
    return biases
302
303
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function

    Without an activation on the operation, a NONE_OR_RELU activation is returned
    with no further fields set. Otherwise min, max and lookup table index are
    copied from the operation's activation.

    Raises UnsupportedFeatureError if the activation is not Tanh, Sigmoid, LUT or a relu op.
    """
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    act.lookup_table_index = op.activation.lut_index
    return act
324
325
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation

    Populates IFM, OFM, weights, biases, activation, fused quantize, rounding mode,
    block config and IFM upscale; padding and kernel are only set for
    non-elementwise operations. The given npu_op is modified in place and also returned.
    """
    ps = cmd.ps
    op = ps.primary_op

    # IFM: height comes from the IFM box, width from the full IFM shape
    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: all dimensions come from the OFM box
    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
        # Biases are addressed via the weight tensor's stream index, so they require weights
        if cmd.scale_tensor is not None:
            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    # Index 2 of ps.block_config is not used for the NPU block config
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op
358
359
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation

    Vector products always use depth first block traversal; for other operations
    the traversal is taken from the weight tensor.
    """
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        npu_op.block_traversal = block_traversal_map[cmd.weight_tensor.block_traversal]
    return npu_op
369
370
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation

    Only the common operation fields are set; there are no depthwise specific fields.
    """
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op
376
377
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation

    Max pool ops map to MAX, average pool ops and ResizeBilinear to AVERAGE and
    ReduceSum to REDUCE_SUM. Any other op type triggers an assertion.
    """
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    return npu_op
396
397
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation

    For binary operations IFM2 is populated as well. If the scalar/broadcasted
    input is not already IFM2, the two inputs are swapped.

    NOTE: the swap modifies cmd (ifm/ifm2 tensors and boxes) and ps.ifm_shapes in place.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            # Height and depth come from the IFM2 box, width from the full IFM2 shape
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type == Op.RescaleAdd:
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        # Only the scale is overridden, the zero point that was set earlier is kept
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
446
447
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation

    Source and destination address ranges have the same length. A LUT destination
    uses the BASE_PTR_INDEX_MEM2MEM region instead of the region of the out tensor.
    """
    src_region = get_region(cmd.in_tensor, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor, arch)

    start_coord = cmd.box.start_coord
    src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
    dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)

    if cmd.in_tensor.compressed_values is not None:
        if cmd.out_tensor.purpose == TensorPurpose.FSBias:
            # The whole tensor is transferred
            sz = cmd.in_tensor.storage_size()
        else:
            # Only the compressed stream that contains the start coordinate is transferred
            stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
            sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
    else:
        # Uncompressed: size is the address distance between the end and the start of the box
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)
471
472
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    DMA commands become DMA operations; NpuStripe commands are dispatched on the
    NPU block type of the primary op of their pass.

    Raises AssertionError for a stripe with an unknown block type and for a
    command that is neither a DMA nor an NpuStripe.
    """
    if isinstance(cmd, DMA):
        return create_dma_op(cmd, arch)
    if isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            return create_npu_conv2d_op(cmd, arch)
        if npu_block_type == NpuBlockType.ConvolutionDepthWise:
            return create_npu_conv_depthwise_op(cmd, arch)
        if npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            return create_npu_pool_op(cmd, arch)
        if npu_block_type == NpuBlockType.ElementWise:
            return create_npu_elementwise_op(cmd, arch)
        # Raised explicitly (rather than "assert 0") so that it is not stripped when running with -O
        raise AssertionError(f"Unknown command type {npu_block_type}")
    # Without this, an unexpected command class would surface as an unbound local variable error
    raise AssertionError(f"Unknown command type {type(cmd).__name__}")
Louis Verhaard1e170182020-11-26 11:42:04 +0100491
492
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream

    Stripes whose pass has the Default NPU block type are skipped with a warning.
    Nothing is generated (and sg is left untouched) when the high level command
    stream of the subgraph is empty.
    """
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = {}  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        # The callback closes over stream_id, which only exists for non-empty command streams
        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd)