blob: 1059e6e7123fa1b40cc0133a3b6110f1d14241da [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
19from enum import IntEnum
20from typing import List
21from typing import Optional
22
23from .api import NpuActivation
24from .api import NpuActivationOp
25from .api import NpuAddressRange
26from .api import NpuBlockOperation
27from .api import NpuBlockTraversal
28from .api import NpuConv2DOperation
29from .api import NpuConvDepthWiseOperation
30from .api import NpuDataType
31from .api import NpuDmaOperation
32from .api import NpuElementWiseOp
33from .api import NpuElementWiseOperation
34from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010035from .api import NpuLayout
36from .api import NpuOperation
37from .api import NpuPadding
38from .api import NpuPoolingOp
39from .api import NpuPoolingOperation
40from .api import NpuQuantization
41from .api import NpuResamplingMode
42from .api import NpuRoundingMode
43from .api import NpuShape3D
44from .api import NpuTileBox
45from .architecture_features import ArchitectureFeatures
46from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010047from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000048from .errors import UnsupportedFeatureError
Louis Verhaarde8a5a782020-11-02 18:04:27 +010049from .high_level_command_stream import Box
50from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010051from .high_level_command_stream import DMA
52from .high_level_command_stream import NpuStripe
Louis Verhaarde8a5a782020-11-02 18:04:27 +010053from .operation import NpuBlockType
54from .operation import Op
55from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010056from .register_command_stream_generator import generate_command_stream
57from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010058from .register_command_stream_util import to_npu_kernel
59from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000060from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010061from .tensor import MemType
62from .tensor import Tensor
63from .tensor import TensorBlockTraversal
64from .tensor import TensorFormat
65from .tensor import TensorPurpose
66
67
class BasePointerIndex(IntEnum):
    """Base pointer (region) indices; get_region() returns the .value of one of these."""

    # Region used for the weight tensor
    WeightTensor = 0
    # Region used for the scratch tensor in the TensorArena
    ScratchTensor = 1
    # Region used for the fast scratch tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010072
73
# Translation from the compiler's internal DataType to the NPU API data type.
# Only these five types can be expressed in a feature map; any other dtype
# makes the dtype_map[...] lookup in create_feature_map raise KeyError.
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
81
82
# Translation from the weight tensor's block traversal order to the NPU API
# equivalent; consulted by create_npu_conv2d_op for non-VectorProduct ops.
block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}
87
88
# Maps an elementwise operator type to the elementwise_mode enum value used by
# NPU_OP_ELEMENTWISE. Note that RescaleAdd shares the ADD mode; its extra
# rescale is attached separately in create_npu_elementwise_op.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
103
104
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the operands of a binary elementwise op are already in hardware order.

    The hardware expects a scalar operand (shape ``[]``) or a broadcast operand in
    IFM2. Returns False when IFM is a scalar, or when some dimension of IFM is 1
    while the matching IFM2 dimension differs (IFM is the broadcast one); in those
    cases the caller must swap the operands. Returns True otherwise.
    """
    if ifm_shape == []:
        return False
    if ifm2_shape == []:
        return True
    # Dimensions are compared pairwise from the front (zip stops at the shorter shape)
    ifm_is_broadcast = any(lhs == 1 and lhs != rhs for lhs, rhs in zip(ifm_shape, ifm2_shape))
    return not ifm_is_broadcast
117
118
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies the type of rounding to be used for op.

    - ResizeBilinear truncates.
    - int16 MxN/depthwise convolutions use NATURAL rounding.
    - An AvgPool with a 1x1 kernel that writes into a concatenation
      (memory_function ConcatSliceWrite) and has no fused Quantize also uses NATURAL.
    - Everything else uses TFL rounding.
    A non-None op.rounding_mode overrides all of the above.
    """
    if op.type == Op.ResizeBilinear:
        mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ) or (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        mode = NpuRoundingMode.NATURAL
    else:
        mode = NpuRoundingMode.TFL

    explicit_mode = op.rounding_mode
    return mode if explicit_mode is None else explicit_mode
139
140
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Computes the padding to apply for one stripe of primary_op.

    VectorProduct operations get no padding. Otherwise the operator's
    "explicit_padding" attribute (top, left, bottom, right) is used, adjusted for
    the stripe:
    - If the IFM is split into several horizontal stripes, top/bottom come from the stripe.
    - Left padding is dropped when the IFM box does not start at column 0.
    - Right padding is dropped when the box ends before the full IFM width.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    pad_top, pad_left, pad_bottom, pad_right = primary_op.attrs["explicit_padding"]

    # Horizontal IFM streaming: this stripe carries its own vertical padding
    if not cmd.is_first_h_stripe or not cmd.is_last_h_stripe:
        pad_top, pad_bottom = cmd.pad_top, cmd.pad_bottom

    # Index from the end, since a 1x1 AvgPool may have been added (to fuse an
    # activation function) with non 4-dimensional input/output.
    box_start = cmd.ifm_box.start_coord
    box_end = cmd.ifm_box.end_coord
    if len(box_start) >= 2 and box_start[-2] > 0:
        pad_left = 0
    if len(box_end) >= 2 and box_end[-2] < cmd.ps.ifm_shapes[0].width:
        pad_right = 0
    return NpuPadding(top=pad_top, left=pad_left, bottom=pad_bottom, right=pad_right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100158
159
def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer index (region) that holds tens.

    Permanent (NPU or CPU) tensors live in the weight region and Scratch tensors
    in the scratch region. Scratch_fast tensors get their own region only when
    spilling is enabled; otherwise they share the scratch region.

    Raises KeyError (keyed by the memory type) for any other memory type.
    """
    spilling = arch.is_spilling_enabled()
    mem_type = tens.mem_type
    if mem_type in (MemType.Permanent_NPU, MemType.Permanent_CPU):
        index = BasePointerIndex.WeightTensor
    elif mem_type == MemType.Scratch:
        index = BasePointerIndex.ScratchTensor
    elif mem_type == MemType.Scratch_fast:
        index = BasePointerIndex.ScratchFastTensor if spilling else BasePointerIndex.ScratchTensor
    else:
        raise KeyError(mem_type)
    return index.value
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100173
174
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode that the hardware must apply for op."""
    if op.type == Op.ResizeBilinear:
        # Nearest-neighbour upscale
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # Insert-zero upscale
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
184
185
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to program for a stripe.

    ConvolutionMxN, VectorProduct and ReduceSum take the depth of the IFM box;
    all other block types use the depth of the OFM box.
    """
    uses_ifm_depth = npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
    depth_box = ifm_box if uses_ifm_depth else ofm_box
    return depth_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100192
193
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization of tens should use 0 as zero point.

    An int32 IFM always does. Otherwise only AvgPool, ResizeBilinear, CLZ and SHL
    primary ops qualify, and only when the op neither writes into a concatenation
    (ConcatSliceWrite) nor has a fused Quantize. In addition, the op must either
    have no activation or have a forced output quantization.
    """
    if is_ifm_tensor and tens.dtype == DataType.int32:
        return True
    primary = ps.primary_op
    if primary.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    has_fused_quantize = any(candidate.type == Op.Quantize for candidate in ps.ops)
    forced_quant = primary.forced_output_quantization
    if primary.memory_function == Op.ConcatSliceWrite or has_fused_quantize:
        return False
    return primary.activation is None or forced_quant is not None
208
209
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets the quantization for an IFM/IFM2 tensor, or None if tens is unquantized.

    The zero point is forced to 0 when use_zero_point_0() says so.
    """
    quant = tens.quantization
    if quant is None:
        return None
    zp = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zp)
219
220
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets the quantization for the OFM, or None if there is none.

    The primary op's forced output quantization (used by LUTs), when set, takes
    precedence over the quantization stored on the OFM tensor. The zero point is
    forced to 0 when use_zero_point_0() says so.
    """
    forced = ps.primary_op.forced_output_quantization
    quant = tens.quantization if forced is None else forced
    if quant is None:
        return None
    zp = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zp)
234
235
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates an NpuFeatureMap for tens with the common fields populated.

    Fills in region, data type, layout (NHWC or NHCWB16 only; other formats fail
    the assertion), the rolling-buffer tile box for box, and the strides. Shape
    and quantization are left for the caller to set.
    """
    fmap = NpuFeatureMap()
    fmap.region = get_region(tens, arch)
    fmap.data_type = dtype_map[tens.dtype]

    supported_layouts = {TensorFormat.NHWC: NpuLayout.NHWC, TensorFormat.NHCWB16: NpuLayout.NHCWB16}
    if tens.format in supported_layouts:
        fmap.layout = supported_layouts[tens.format]
    else:
        assert 0, "Incorrect tensor format"

    tile_h0, tile_h1, tile_w0, tile_addrs = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Tiles without an address are programmed as address 0 (the list is patched in place)
    for pos, tile_addr in enumerate(tile_addrs):
        if tile_addr is None:
            tile_addrs[pos] = 0
    fmap.tiles = NpuTileBox(
        height_0=tile_h0, height_1=tile_h1, width_0=tile_w0, addresses=[int(a) for a in tile_addrs]
    )

    # Stride entries 1, 2 and 3 are the depth, height and width strides respectively
    strides = tens.get_strides(shape4D=op_shape4D)
    fmap.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fmap
259
260
def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns one address range per compressed weight substream.

    The compressed stream is selected by the start coordinate of weight_box. Each
    range starts at that coordinate's address plus the substream's offset.
    """
    stream_idx = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = weight_tensor.compressed_values_substream_offsets[stream_idx]
    # The offset list starts at 0 and must terminate with the full stream length,
    # so every adjacent (begin, end) pair delimits exactly one substream.
    assert len(offsets) > 1 and (offsets[0] == 0)
    base_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
    return [
        NpuAddressRange(region, int(base_addr + begin), int(end - begin))
        for begin, end in zip(offsets, offsets[1:])
    ]
278
279
def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns one address range per compressed scale/bias substream.

    The stream index is taken from weight_tensor at the start of weight_box. The
    base address in scale_tensor uses only the last component of that coordinate.
    """
    stream_idx = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = scale_tensor.compressed_values_substream_offsets[stream_idx]
    # The offset list starts at 0 and must terminate with the full stream length,
    # so every adjacent (begin, end) pair delimits exactly one substream.
    assert len(offsets) > 1 and (offsets[0] == 0)
    base_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])
    region = get_region(scale_tensor, arch)
    return [
        NpuAddressRange(region, int(base_addr + begin), int(end - begin))
        for begin, end in zip(offsets, offsets[1:])
    ]
300
301
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates the NpuActivation describing op's fused activation function.

    No activation yields a plain NONE_OR_RELU. Tanh, Sigmoid and LUT map to their
    dedicated activation ops, and any ReLU variant maps to NONE_OR_RELU; the
    clamp range and LUT index are copied from op.activation.

    Raises UnsupportedFeatureError for any other fused activation.
    """
    fused = op.activation
    if fused is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    act_type = fused.op_type
    dedicated_ops = {
        Op.Tanh: NpuActivationOp.TANH,
        Op.Sigmoid: NpuActivationOp.SIGMOID,
        Op.LUT: NpuActivationOp.TABLE_LOOKUP,
    }
    if act_type in dedicated_ops:
        act_kind = dedicated_ops[act_type]
    elif act_type.is_relu_op():
        act_kind = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {act_type.name}")

    npu_act = NpuActivation(act_kind)
    npu_act.min = fused.min
    npu_act.max = fused.max
    npu_act.lookup_table_index = fused.lut_index
    return npu_act
322
323
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Populates the fields shared by every NPU block operation and returns npu_op.

    Sets the following on npu_op:
    - IFM/OFM feature maps with their shapes and quantization.
    - Weights and biases, when the stripe has them.
    - Activation, fused-quantize flag, rounding mode and block config.
    - For non-elementwise ops only: padding, kernel and IFM upscale mode.
    """
    ps = cmd.ps
    primary = ps.primary_op

    # IFM height follows the stripe's box while the width is the full IFM width
    ifm_blk_height = cmd.ifm_box.get_block().height
    ifm_full_width = ps.ifm_shapes[0].width
    ifm_blk_depth = get_ifm_depth(primary.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_blk_height, width=ifm_full_width, depth=ifm_blk_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    ofm_blk = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_blk.height, width=ofm_blk.width, depth=ofm_blk.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
    if cmd.scale_tensor is not None:
        npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(primary)
    npu_op.fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary, npu_op.fused_quantize)
    # block_config entries 0, 1 and 3 hold height, width and depth
    cfg = ps.block_config
    npu_op.block_config = NpuShape3D(height=cfg[0], width=cfg[1], depth=cfg[3])

    if not primary.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary)
        npu_op.kernel = to_npu_kernel(primary.kernel)
        npu_op.ifm_upscale = get_upscale(primary)
    return npu_op
356
357
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts an MxN-convolution or VectorProduct stripe to an NpuConv2DOperation.

    VectorProduct ops always use depth-first block traversal. Otherwise the
    traversal recorded on the weight tensor is used.
    """
    conv_op = NpuConv2DOperation()
    set_common_op_fields(conv_op, cmd, arch)
    is_vector_product = cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct
    conv_op.block_traversal = (
        NpuBlockTraversal.DEPTH_FIRST if is_vector_product else block_traversal_map[cmd.weight_tensor.block_traversal]
    )
    return conv_op
367
368
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts a depthwise-convolution stripe to an NpuConvDepthWiseOperation.

    Depthwise ops need no fields beyond the common ones.
    """
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
374
375
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts a pooling or ReduceSum stripe to an NpuPoolingOperation.

    Max-pools map to MAX; avg-pools and ResizeBilinear map to AVERAGE; ReduceSum
    maps to REDUCE_SUM. For ResizeBilinear the op's rescale is also copied over.
    """
    primary = cmd.ps.primary_op
    op_type = primary.type
    # AVERAGE is also what remains if the assertion below is compiled out (-O)
    pool_kind = NpuPoolingOp.AVERAGE
    if op_type.is_maxpool_op():
        pool_kind = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pool_kind = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pool_kind = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op_type}"

    pool_op = NpuPoolingOperation(pool_kind)
    set_common_op_fields(pool_op, cmd, arch)
    if op_type == Op.ResizeBilinear:
        pool_op.rescale = primary.rescale
    return pool_op
395
396
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts an elementwise stripe to an NpuElementWiseOperation.

    Binary ops:
    - If IFM is a scalar or the broadcast operand, the operands are swapped so
      that it becomes IFM2. The swap is applied to cmd.ifm*_tensor, cmd.ifm*_box
      and ps.ifm_shapes, and is flagged with npu_op.reversed_operands.
    - A scalar IFM2 gets its value in npu_op.ifm2_scalar and an all-zero shape.

    The OFM scale may be overridden afterwards for these cases:
    - Add standing in for a 1x1 ResizeBilinear.
    - Abs.
    - LeakyRelu.
    - Add/Mul/Sub with a fused Sigmoid/Tanh.
    RescaleAdd carries its rescale over.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    ew_kind = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(ew_kind)

    if ew_kind not in UNARY_ELEMWISE_OPS:
        lhs_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        rhs_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(lhs_shape, rhs_shape):
            # The scalar/broadcast feature map has to be IFM2, so exchange the operands.
            # NOTE(review): ps.ifm_shapes belongs to the pass, not to this command. If
            # several NpuStripe commands share the same pass, later stripes see the
            # already-swapped shapes next to their own unswapped tensors. Confirm that
            # such passes are never split into multiple stripes.
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # Scalar operand: passed by value, the feature map itself is empty
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            rhs_block = cmd.ifm2_box.get_block()
            rhs_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=rhs_block.height, width=rhs_width, depth=rhs_block.depth)

    set_common_op_fields(npu_op, cmd, arch)

    # Work out whether the OFM scale must be overridden
    scale_override = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # A 1x1 resizebilinear converted to Add keeps the input scale as output scale
        scale_override = npu_op.ifm2.quantization.scale_f32
    elif op.type == Op.Abs:
        scale_override = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    elif op.type == Op.LeakyRelu:
        scale_override = op.attrs["alpha"]
    if op.type == Op.RescaleAdd:
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        fused_act = op.activation
        if fused_act is not None and fused_act.op_type in (Op.Sigmoid, Op.Tanh):
            scale_override = 1 / 0x3000
    if scale_override is not None:
        npu_op.ofm.quantization = NpuQuantization(
            scale_f32=scale_override, zero_point=npu_op.ofm.quantization.zero_point
        )
    return npu_op
445
446
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts a DMA command to an NpuDmaOperation.

    Source and destination both start at the command box's start coordinate. The
    transfer size is determined as follows:
    - Uncompressed source: the address span of the box.
    - Compressed source feeding an FSBias tensor: the full storage size.
    - Any other compressed source: the size of the selected compressed stream.
    """
    src_tens = cmd.in_tensor
    dst_tens = cmd.out_tensor
    src_region = get_region(src_tens, arch)
    # LUT destinations use the BASE_PTR_INDEX_MEM2MEM index rather than a tensor-derived region
    dest_region = BASE_PTR_INDEX_MEM2MEM if dst_tens.purpose == TensorPurpose.LUT else get_region(dst_tens, arch)

    origin = cmd.box.start_coord
    src_addr = src_tens.address_for_coordinate(origin)
    dest_addr = dst_tens.address_for_coordinate(origin)

    if src_tens.compressed_values is None:
        transfer_size = src_tens.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    elif dst_tens.purpose == TensorPurpose.FSBias:
        transfer_size = src_tens.storage_size()
    else:
        transfer_size = src_tens.size_of_compressed_stream(src_tens.compressed_stream_index_from_coord(origin))

    return NpuDmaOperation(
        NpuAddressRange(src_region, int(src_addr), int(transfer_size)),
        NpuAddressRange(dest_region, int(dest_addr), int(transfer_size)),
    )
470
471
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts a high level command to the corresponding NpuOperation.

    DMA commands become an NpuDmaOperation. NpuStripe commands are dispatched on
    the NPU block type of the pass's primary operation:
    - ConvolutionMxN / VectorProduct -> NpuConv2DOperation
    - ConvolutionDepthWise -> NpuConvDepthWiseOperation
    - Pooling / ReduceSum -> NpuPoolingOperation
    - ElementWise -> NpuElementWiseOperation

    Raises AssertionError for an NpuStripe with any other block type, and for a
    command that is neither a DMA nor an NpuStripe.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    else:
        # Without this branch an unrecognised command class would reach the return
        # below with npu_op unbound and fail with an unhelpful UnboundLocalError.
        assert 0, f"Unknown command type {type(cmd).__name__}"
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100490
491
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates the register command stream for subgraph sg.

    Steps:
    - Convert every high level command to an NpuOperation. NpuStripe commands whose
      pass has the Default NPU block type are skipped with a printed warning.
    - Register the stream in the DebugDatabase and store its id in sg.generated_stream_id.
    - Store the output of generate_command_stream in sg.register_command_stream.

    Does nothing when sg has no high level commands. nng is not used.
    """
    hl_stream = sg.high_level_command_stream
    if len(hl_stream) == 0:
        return

    npu_ops = []
    op_to_hl_cmd = dict()  # NpuOperation -> the high level command it came from
    for hl_cmd in hl_stream:
        if isinstance(hl_cmd, NpuStripe) and hl_cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", hl_cmd.ps)
            continue
        converted = convert_command_to_npu_op(hl_cmd, arch)
        npu_ops.append(converted)
        op_to_hl_cmd[converted] = hl_cmd

    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Records the primary op behind npu_op at offset (DMA operations are not recorded)."""
        if isinstance(npu_op, NpuDmaOperation):
            return
        DebugDatabase.add_command(stream_id, offset, op_to_hl_cmd[npu_op].ps.primary_op)

    sg.register_command_stream = generate_command_stream(npu_ops, arch, verbose, add_to_debug_db, op_to_hl_cmd)
erik.andersson@arm.comad45f792021-02-03 10:20:16 +0100515 sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd)