# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorBlockTraversal
from .tensor import TensorFormat
from .tensor import TensorPurpose


class BasePointerIndex(IntEnum):
    WeightTensor = 0  # base address index for the Weight tensor
    ScratchTensor = 1  # base address index for the Scratch_tensor in the TensorArena
    ScratchFastTensor = 2  # base address for the Scratch_fast_tensor


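# Maps a tensor DataType to the corresponding NpuDataType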
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}


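# Maps a TensorBlockTraversal to the corresponding NpuBlockTraversal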
block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}


# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}


def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
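    """Returns True if the operands are already ordered so that a scalar or broadcast
    feature map ends up as IFM2. A sketch of the intended behaviour:

        ifm_ifm2_correct_order([1, 8, 8, 16], [1, 8, 8, 16]) -> True
        ifm_ifm2_correct_order([], [1, 8, 8, 16]) -> False   # scalar must be IFM2
        ifm_ifm2_correct_order([1, 1, 8, 16], [1, 8, 8, 16]) -> False  # broadcast FM must be IFM2
    """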
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    for ifm, ifm2 in zip(ifm_shape, ifm2_shape):
        if ifm != ifm2 and ifm == 1:
            # Broadcasted FM needs to be in IFM2
            return False
    return True


def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used"""
    rounding_mode = NpuRoundingMode.TFL
    if op.type == Op.ResizeBilinear:
        rounding_mode = NpuRoundingMode.TRUNCATE
    elif (
        op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op.type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        rounding_mode = NpuRoundingMode.NATURAL
    if op.rounding_mode is not None:
        rounding_mode = op.rounding_mode
    return rounding_mode


def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
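    """Returns the padding to use for the stripe, derived from the primary op's explicit padding"""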
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)
    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        top = cmd.pad_top
        bottom = cmd.pad_bottom

    # Index from the end, since a 1x1 Avgpool with non 4-dimensional input/output might have been
    # added because an activation function needed to be fused.
    if len(cmd.ifm_box.start_coord) >= 2 and cmd.ifm_box.start_coord[-2] > 0:
        left = 0
    if len(cmd.ifm_box.end_coord) >= 2 and cmd.ifm_box.end_coord[-2] < cmd.ps.ifm_shapes[0].width:
        right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)


def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
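    """Returns the base pointer index (region) used for tensors of the given memory type"""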
    base_ptr_idx_map = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
    }

    if arch.is_spilling_enabled():
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchFastTensor
    else:
        base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor

    return base_ptr_idx_map[mem_type].value


def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes"""
    mem_limits = dict()
    for mem_type in MemType.all():
        mem_limits[get_region(mem_type, arch)] = arch.mem_type_size(mem_type)
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits


def get_upscale(op: Operation) -> NpuResamplingMode:
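    """Returns the IFM resampling (upscale) mode to use for the operation"""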
    upscale = NpuResamplingMode.NONE
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        upscale = NpuResamplingMode.NEAREST
    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        upscale = NpuResamplingMode.TRANSPOSE
    return upscale


def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
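    """Returns the IFM depth: from the IFM box for convolution/vector product/reduce sum, otherwise from the OFM box"""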
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
        block = ifm_box.get_block()
    else:
        block = ofm_box.get_block()
    return block.depth


def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point"""
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True
    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    forced_ofm_quantization = ps.primary_op.forced_output_quantization
    use_0 = (
        (ps.primary_op.activation is None or forced_ofm_quantization is not None)
        and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
        and not fused_quantize
    )
    return use_0


def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2"""
    op = ps.primary_op
    ifm_quant = op.forced_input_quantization if op.forced_input_quantization is not None else tens.quantization
    if ifm_quant is None:
        return None
    if use_zero_point_0(ps, tens, True):
        zero_point = 0
    else:
        zero_point = int(ifm_quant.zero_point)
    return NpuQuantization(scale_f32=ifm_quant.scale_f32, zero_point=zero_point)


def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM"""
    op = ps.primary_op
    # Check if the operation's output quantization should be used instead of the output tensor's quantization
    # (used in LUTs)
    ofm_quant = op.forced_output_quantization if op.forced_output_quantization is not None else tens.quantization
    if ofm_quant is None:
        return None
    if use_zero_point_0(ps, tens, False):
        zero_point = 0
    else:
        zero_point = int(ofm_quant.zero_point)
    return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)


def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated"""
    fm = NpuFeatureMap()
    fm.region = get_region(tens.mem_type, arch)
    fm.data_type = dtype_map[tens.dtype]
    if tens.format == TensorFormat.NHWC:
        fm.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        fm.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"
    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    for idx, addr in enumerate(addresses):
        if addr is None:
            addresses[idx] = 0
    fm.tiles = NpuTileBox(
        height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
    )
    strides = tens.get_strides(shape4D=op_shape4D)
    fm.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return fm


def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights"""
    weights = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    weight_substream_offsets = weight_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(weight_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract weight substream offsets and calculate their lengths
    assert len(weight_substream_offsets) > 1 and (weight_substream_offsets[0] == 0)
    weight_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor.mem_type, arch)
    for core in range(substreams):
        address = weight_addr + weight_substream_offsets[core]
        length = weight_substream_offsets[core + 1] - weight_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        weights.append(addr_range)
    return weights


def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases"""
    biases = []
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    scale_substream_offsets = scale_tensor.compressed_values_substream_offsets[stream_index]
    substreams = len(scale_substream_offsets) - 1  # Offset list must terminate with full stream length

    # Extract scale substream offsets and calculate their lengths
    assert len(scale_substream_offsets) > 1 and (scale_substream_offsets[0] == 0)
    scale_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])

    region = get_region(scale_tensor.mem_type, arch)
    for core in range(substreams):
        address = scale_addr + scale_substream_offsets[core]
        length = scale_substream_offsets[core + 1] - scale_substream_offsets[core]
        addr_range = NpuAddressRange(region, int(address), int(length))
        biases.append(addr_range)
    return biases


def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function"""
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)
    faf = op.activation.op_type
    act_op = NpuActivationOp.NONE_OR_RELU
    if faf == Op.Tanh:
        act_op = NpuActivationOp.TANH
    elif faf == Op.Sigmoid:
        act_op = NpuActivationOp.SIGMOID
    elif faf == Op.LUT:
        act_op = NpuActivationOp.TABLE_LOOKUP
    elif not faf.is_relu_op():
        raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max
    act.lookup_table_index = op.activation.lut_index
    return act


def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation"""
    ps = cmd.ps
    op = ps.primary_op

    ifm_height = cmd.ifm_box.get_block().height
    ifm_width = cmd.ps.ifm_shapes[0].width
    ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)

    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    out_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
        if cmd.scale_tensor is not None:
            npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)
    npu_op.activation = create_npu_activation(op)
    npu_op.fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(op, npu_op.fused_quantize)
    npu_op.block_config = NpuShape3D(height=ps.block_config[0], width=ps.block_config[1], depth=ps.block_config[3])

    if not op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, op)
        npu_op.kernel = to_npu_kernel(op.kernel)
    npu_op.ifm_upscale = get_upscale(op)
    return npu_op


def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation"""
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        npu_op.block_traversal = block_traversal_map[cmd.weight_tensor.block_traversal]
    return npu_op


def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation"""
    npu_op = NpuConvDepthWiseOperation()
    set_common_op_fields(npu_op, cmd, arch)
    return npu_op


def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation"""
    ps = cmd.ps
    op = ps.primary_op
    pool_op = NpuPoolingOp.AVERAGE
    if op.type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op.type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {op.type}"
    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)
    # Pooling specific info
    npu_op.rescale = op.rescale
    return npu_op


def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation"""
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        ifm_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        ifm2_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(ifm_shape, ifm2_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True
        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # scalar
            assert cmd.ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = cmd.ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_blk = cmd.ifm2_box.get_block()
            ifm2_width = ps.ifm_shapes[1].width
            npu_op.ifm2.shape = NpuShape3D(height=ifm2_blk.height, width=ifm2_width, depth=ifm2_blk.depth)
    set_common_op_fields(npu_op, cmd, arch)
    # Check if output scale needs to be overridden
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    if op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    if op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    if op.type == Op.RescaleAdd:
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale
    if op.type in (Op.Add, Op.Mul, Op.Sub):
        if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
            output_scale = 1 / 0x3000
    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op


def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation"""
    src_region = get_region(cmd.in_tensor.mem_type, arch)
    if cmd.out_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(cmd.out_tensor.mem_type, arch)

    start_coord = cmd.box.start_coord
    src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
    dest_addr = cmd.out_tensor.address_for_coordinate(start_coord)

    if cmd.in_tensor.compressed_values is not None:
        if cmd.out_tensor.purpose == TensorPurpose.FSBias:
            sz = cmd.in_tensor.storage_size()
        else:
            stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
            sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
    else:
        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    src = NpuAddressRange(src_region, int(src_addr), int(sz))
    dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
    return NpuDmaOperation(src, dest)


def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation"""
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    return npu_op


def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = dict()  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            npu_op = convert_command_to_npu_op(cmd, arch)
            npu_op_list.append(npu_op)
            npu_op_to_cmd[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)
    # Generate register commands
    if len(sg.high_level_command_stream) > 0:
        stream_id = DebugDatabase.add_stream(sg)
        sg.generated_stream_id = stream_id

        def add_to_debug_db(npu_op: NpuOperation, offset: int):
            """Adds info to the debug database"""
            if not isinstance(npu_op, NpuDmaOperation):
                cmd = npu_op_to_cmd[npu_op]
                DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)

        sg.register_command_stream = generate_command_stream(
            npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
        )