blob: b5e7b4b9e171aee62dc0075497e9ac333f5ce70d [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
19from enum import IntEnum
20from typing import List
21from typing import Optional
22
23from .api import NpuActivation
24from .api import NpuActivationOp
25from .api import NpuAddressRange
26from .api import NpuBlockOperation
27from .api import NpuBlockTraversal
28from .api import NpuConv2DOperation
29from .api import NpuConvDepthWiseOperation
30from .api import NpuDataType
31from .api import NpuDmaOperation
32from .api import NpuElementWiseOp
33from .api import NpuElementWiseOperation
34from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010035from .api import NpuLayout
36from .api import NpuOperation
37from .api import NpuPadding
38from .api import NpuPoolingOp
39from .api import NpuPoolingOperation
40from .api import NpuQuantization
41from .api import NpuResamplingMode
42from .api import NpuRoundingMode
43from .api import NpuShape3D
44from .api import NpuTileBox
45from .architecture_features import ArchitectureFeatures
46from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010047from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000048from .errors import UnsupportedFeatureError
Louis Verhaarde8a5a782020-11-02 18:04:27 +010049from .high_level_command_stream import Box
50from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010051from .high_level_command_stream import DMA
52from .high_level_command_stream import NpuStripe
Louis Verhaarde8a5a782020-11-02 18:04:27 +010053from .operation import NpuBlockType
54from .operation import Op
55from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010056from .register_command_stream_generator import generate_command_stream
57from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010058from .register_command_stream_util import to_npu_kernel
59from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000060from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010061from .tensor import MemType
62from .tensor import Tensor
63from .tensor import TensorBlockTraversal
64from .tensor import TensorFormat
65from .tensor import TensorPurpose
66
67
class BasePointerIndex(IntEnum):
    """Base pointer indices; get_region() returns the value of one of these members as a region number."""

    # Base address index for the Weight tensor
    WeightTensor = 0
    # Base address index for the Scratch_tensor in the TensorArena
    ScratchTensor = 1
    # Base address for the Scratch_fast_tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010072
73
# Lookup table translating a tensor DataType into the NpuDataType stored in a feature map.
# Indexing with a data type that is not listed here raises KeyError.
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
81
82
# Lookup table translating a weight tensor's TensorBlockTraversal into the NpuBlockTraversal
# that is set on a convolution operation.
block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}
87
88
# Lookup table from an elementwise op type to the elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# Note that both Op.Add and Op.RescaleAdd translate to NpuElementWiseOp.ADD.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
103
104
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the two elementwise inputs are already in the order required for IFM/IFM2.

    Returns False when the first input is a scalar (empty shape) or is broadcast along some
    dimension (size 1 where the second input differs), since such an input has to be IFM2.
    Returns True otherwise.
    """
    if ifm_shape == []:
        # A scalar may only be passed as IFM2
        return False
    if ifm2_shape == []:
        return True

    # Any dimension where the first input is broadcast means the inputs must be swapped
    return not any(dim != dim2 and dim == 1 for dim, dim2 in zip(ifm_shape, ifm2_shape))
117
118
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Selects the rounding mode for the operation.

    A "rounding_mode" entry in op.attrs, when present, overrides the mode derived from the
    operation type.
    """
    if op.type == Op.ResizeBilinear:
        derived_mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        # 16-bit convolutions
        derived_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        # Single-element-kernel average pool writing a concat slice, without fused quantize
        derived_mode = NpuRoundingMode.NATURAL
    else:
        derived_mode = NpuRoundingMode.TFL
    return op.attrs.get("rounding_mode", derived_mode)
138
139
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Computes the padding for the stripe, starting from the op's "explicit_padding" attribute.

    Vector products always get zero padding.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    pad_top, pad_left, pad_bottom, pad_right = primary_op.attrs["explicit_padding"]

    # Horizontal ifm streaming: the stripe does not cover the full height, so use the
    # vertical padding that was calculated for this particular stripe
    covers_full_height = cmd.is_first_h_stripe and cmd.is_last_h_stripe
    if not covers_full_height:
        pad_top, pad_bottom = cmd.pad_top, cmd.pad_bottom

    # The width coordinate is indexed from the end, since a 1x1 Avgpool might have been added with
    # non 4-dimensional input/output, because of activation function needed to be fused.
    box_start = cmd.ifm_box.start_coord
    box_end = cmd.ifm_box.end_coord
    if len(box_start) >= 2 and box_start[-2] > 0:
        # Box does not start at the left edge
        pad_left = 0
    if len(box_end) >= 2 and box_end[-2] < cmd.ps.ifm_shapes[0].width:
        # Box does not reach the right edge
        pad_right = 0
    return NpuPadding(top=pad_top, left=pad_left, bottom=pad_bottom, right=pad_right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100157
158
def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer index (region) to use for the tensor, based on its memory type.

    Raises KeyError for a memory type that has no entry in the table below.
    """
    # Scratch_fast only has a region of its own when spilling is enabled
    scratch_fast_index = (
        BasePointerIndex.ScratchFastTensor if arch.is_spilling_enabled() else BasePointerIndex.ScratchTensor
    )
    region_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return region_by_mem_type[tens.mem_type].value
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100172
173
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode that goes with the operation type."""
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
183
184
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth: taken from the IFM box for MxN convolution, vector product and
    reduce sum, and from the OFM box for every other block type."""
    depth_from_ifm = npu_block_type in (
        NpuBlockType.ConvolutionMxN,
        NpuBlockType.VectorProduct,
        NpuBlockType.ReduceSum,
    )
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100191
192
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Decides whether 0 should be used as zero point instead of the quantization's own zero point"""
    # int32 input feature maps always use zero point 0
    if is_ifm_tensor and tens.dtype == DataType.int32:
        return True

    primary_op = ps.primary_op
    # Apart from the case above, only these op types can get zero point 0
    if primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False

    has_fused_quantize = any(ps_op.type == Op.Quantize for ps_op in ps.ops)
    no_activation_or_forced_quant = (
        primary_op.activation is None or primary_op.forced_output_quantization is not None
    )
    not_concat_slice_write = primary_op.memory_function != Op.ConcatSliceWrite
    return no_activation_or_forced_quant and not_concat_slice_write and not has_fused_quantize
207
208
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Returns the quantization to use for an IFM/IFM2 tensor, or None if the tensor has no quantization"""
    quant = tens.quantization
    if quant is None:
        return None
    # Zero point is either forced to 0 or taken from the tensor
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
218
219
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Returns the quantization to use for the OFM, or None if no quantization is available"""
    forced_quant = ps.primary_op.forced_output_quantization
    # The operation's forced output quantization (used in LUTs), when set, takes precedence
    # over the output tensor's own quantization
    ofm_quant = tens.quantization if forced_quant is None else forced_quant
    if ofm_quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
233
234
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated

    Sets region, data type, layout, tiles and strides. Shape and quantization are left for the
    caller to fill in.

    Raises:
        KeyError: if the tensor's data type has no entry in dtype_map.
        AssertionError: if the tensor format is neither NHWC nor NHCWB16.
    """
    fm = NpuFeatureMap()
    fm.region = get_region(tens, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        # Raised explicitly (instead of "assert 0") so that the check is not removed when Python
        # runs with -O, which would leave the layout silently unset
        raise AssertionError("Incorrect tensor format")
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Entries that are None are replaced by address 0 before the conversion to int
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides(shape4D=op_shape4D)
    # Elements 2, 3 and 1 of the tensor strides are used as height, width and depth stride
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm
258
259
def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns one address range per weight substream of the compressed stream selected by the box"""
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    # The offset list must terminate with the full stream length, hence one less substream than offsets
    num_substreams = len(offsets) - 1

    assert len(offsets) > 1 and (offsets[0] == 0)
    base_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor, arch)
    # Each substream starts at its offset from the base address and ends where the next offset begins
    return [
        NpuAddressRange(region, int(base_addr + offsets[core]), int(offsets[core + 1] - offsets[core]))
        for core in range(num_substreams)
    ]
277
278
def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns one address range per scale substream of the compressed stream selected by the box"""
    # The stream index is looked up through the weight tensor, the offsets through the scale tensor
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    # The offset list must terminate with the full stream length, hence one less substream than offsets
    num_substreams = len(offsets) - 1

    assert len(offsets) > 1 and (offsets[0] == 0)
    # Only the last element of the weight box start coordinate is used to address the scale tensor
    base_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])

    region = get_region(scale_tensor, arch)
    # Each substream starts at its offset from the base address and ends where the next offset begins
    return [
        NpuAddressRange(region, int(base_addr + offsets[core]), int(offsets[core + 1] - offsets[core]))
        for core in range(num_substreams)
    ]
299
300
def create_npu_activation(op: Operation) -> NpuActivation:
    """Translates the operation's fused activation function into an NpuActivation

    Raises UnsupportedFeatureError if the activation is not Tanh, Sigmoid, LUT or a relu op.
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif faf.is_relu_op():
        act_op = NpuActivationOp.NONE_OR_RELU
    else:
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    # Min, max and lookup table index are copied over regardless of the activation kind
    npu_activation = NpuActivation(act_op)
    npu_activation.min = activation.min
    npu_activation.max = activation.max
    npu_activation.lookup_table_index = activation.lut_index
    return npu_activation
321
322
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Populates the fields that all block operations share, and returns the given operation

    Fills in ifm, ofm, weights, biases, activation, fused quantize flag, rounding mode and block
    config. Padding and kernel are only set for non-elementwise ops.
    """
    ps = cmd.ps
    primary_op = ps.primary_op

    # Input feature map: height comes from the IFM box, width from the full IFM shape
    ifm_shape = NpuShape3D(
        height=cmd.ifm_box.get_block().height,
        width=cmd.ps.ifm_shapes[0].width,
        depth=get_ifm_depth(primary_op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box),
    )
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = ifm_shape
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # Output feature map: all dimensions come from the OFM box
    ofm_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    # Weights and biases are optional
    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
    if cmd.scale_tensor is not None:
        npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)

    npu_op.activation = create_npu_activation(primary_op)
    npu_op.fused_quantize = any(ps_op.type == Op.Quantize for ps_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)

    # Elements 0, 1 and 3 of the pass' block config are used as height, width and depth
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = get_upscale(primary_op)
    return npu_op
355
356
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Builds the NpuConv2DOperation for the stripe command"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    is_vector_product = cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct
    # Vector products always use depth first; otherwise the traversal follows the weight tensor
    npu_op.block_traversal = (
        NpuBlockTraversal.DEPTH_FIRST
        if is_vector_product
        else block_traversal_map[cmd.weight_tensor.block_traversal]
    )
    return npu_op
366
367
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Builds the NpuConvDepthWiseOperation for the stripe command; only the common fields are set"""
    depthwise_op = NpuConvDepthWiseOperation()
    set_common_op_fields(depthwise_op, cmd, arch)
    return depthwise_op
373
374
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation

    Max pool ops map to MAX, average pool ops and ResizeBilinear map to AVERAGE, and ReduceSum
    maps to REDUCE_SUM.

    Raises:
        AssertionError: if the primary op is none of the above.
    """
    ps = cmd.ps
    op = ps.primary_op
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        # Raised explicitly (instead of "assert 0") so that an unknown op type is never silently
        # treated as average pooling when Python runs with -O
        raise AssertionError(f"Unknown pool type {op.type}")
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    if op.type == Op.ResizeBilinear:
        npu_op.rescale = op.rescale
    return npu_op
394
395
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Builds the NpuElementWiseOperation for the stripe command

    NOTE: if the two inputs of a binary operation are in the wrong order, the ifm/ifm2 tensors and
    boxes of cmd and the first two entries of cmd.ps.ifm_shapes are swapped in place.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        # Binary operation, so a second input feature map has to be set up
        first_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        second_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(first_shape, second_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True

        ifm2_tensor = cmd.ifm2_tensor
        npu_op.ifm2 = create_feature_map(ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, ifm2_tensor)
        if ifm2_tensor.shape == []:
            # Scalar: the single value is passed directly and the shape is zeroed
            assert ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            # Height and depth come from the IFM2 box, width from the full IFM2 shape
            ifm2_block = cmd.ifm2_box.get_block()
            npu_op.ifm2.shape = NpuShape3D(
                height=ifm2_block.height, width=ps.ifm_shapes[1].width, depth=ifm2_block.depth
            )

    set_common_op_fields(npu_op, cmd, arch)

    # Work out whether the output scale has to be overridden; later checks win over earlier ones
    scale_override = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        scale_override = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        scale_override = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        scale_override = op.attrs["alpha"]
    if op.type == Op.RescaleAdd:
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if (
        op.type in (Op.Add, Op.Mul, Op.Sub)
        and op.activation is not None
        and op.activation.op_type in (Op.Sigmoid, Op.Tanh)
    ):
        scale_override = 1 / 0x3000

    if scale_override is not None:
        # Only the scale is replaced; the zero point of the OFM quantization is kept
        npu_op.ofm.quantization = NpuQuantization(
            scale_f32=scale_override, zero_point=npu_op.ofm.quantization.zero_point
        )
    return npu_op
444
445
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Builds the NpuDmaOperation that copies the command's box from in_tensor to out_tensor"""
    in_tensor = cmd.in_tensor
    out_tensor = cmd.out_tensor

    src_region = get_region(in_tensor, arch)
    # A LUT destination uses the dedicated mem2mem base pointer index
    dest_region = (
        BASE_PTR_INDEX_MEM2MEM if out_tensor.purpose == TensorPurpose.LUT else get_region(out_tensor, arch)
    )

    start_coord = cmd.box.start_coord
    src_addr = in_tensor.address_for_coordinate(start_coord)
    dest_addr = out_tensor.address_for_coordinate(start_coord)

    # Determine the number of bytes to transfer
    if in_tensor.compressed_values is None:
        # Uncompressed: distance between the addresses of the box end and the box start
        num_bytes = in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    elif out_tensor.purpose == TensorPurpose.FSBias:
        num_bytes = in_tensor.storage_size()
    else:
        # Compressed: size of the compressed stream that the start coordinate belongs to
        stream_index = in_tensor.compressed_stream_index_from_coord(start_coord)
        num_bytes = in_tensor.size_of_compressed_stream(stream_index)

    src = NpuAddressRange(src_region, int(src_addr), int(num_bytes))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(num_bytes))
    return NpuDmaOperation(src, dest)
469
470
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    DMA commands become DMA operations; NpuStripe commands are dispatched on the npu block type
    of the primary op of their pass.

    Raises:
        TypeError: if cmd is neither a DMA nor an NpuStripe command.
        AssertionError: if the npu block type of an NpuStripe is not handled here.
    """
    if isinstance(cmd, DMA):
        return create_dma_op(cmd, arch)
    if not isinstance(cmd, NpuStripe):
        # Previously this fell through to the return statement with the result variable unbound
        raise TypeError(f"Unsupported command class {type(cmd).__name__}")

    npu_block_type = cmd.ps.primary_op.type.npu_block_type
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
        return create_npu_conv2d_op(cmd, arch)
    if npu_block_type == NpuBlockType.ConvolutionDepthWise:
        return create_npu_conv_depthwise_op(cmd, arch)
    if npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
        return create_npu_pool_op(cmd, arch)
    if npu_block_type == NpuBlockType.ElementWise:
        return create_npu_elementwise_op(cmd, arch)
    # Raised explicitly (instead of "assert 0") so that the check also works when Python runs with -O
    raise AssertionError(f"Unknown command type {npu_block_type}")
Louis Verhaard1e170182020-11-26 11:42:04 +0100489
490
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates the register command stream for the subgraph and stores it in sg.register_command_stream

    Nothing is generated or stored when the subgraph's high level command stream is empty.
    """
    # Step 1: translate the high level commands into NpuOperations
    npu_ops = []
    cmd_by_npu_op = {}  # lookup from npu op back to the high level command it came from
    for hl_cmd in sg.high_level_command_stream:
        is_default_stripe = isinstance(hl_cmd, NpuStripe) and hl_cmd.ps.npu_block_type == NpuBlockType.Default
        if is_default_stripe:
            print("Warning: Skipping register command stream generation for", hl_cmd.ps)
            continue
        converted_op = convert_command_to_npu_op(hl_cmd, arch)
        npu_ops.append(converted_op)
        cmd_by_npu_op[converted_op] = hl_cmd

    # Step 2: generate the register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Records the primary op of a non-DMA operation in the debug database"""
            if isinstance(npu_op, NpuDmaOperation):
                return
            source_cmd = cmd_by_npu_op[npu_op]
            DebugDatabase.add_command(stream_id, offset, source_cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(npu_ops, arch, verbose, add_to_debug_db, cmd_by_npu_op)