blob: ad9e2664e1e04c8fb7bfd0ae4b9da3023e7c1044 [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
19from enum import IntEnum
Louis Verhaard024c3552021-03-17 14:26:34 +010020from typing import Dict
Louis Verhaarde8a5a782020-11-02 18:04:27 +010021from typing import List
22from typing import Optional
23
24from .api import NpuActivation
25from .api import NpuActivationOp
26from .api import NpuAddressRange
27from .api import NpuBlockOperation
28from .api import NpuBlockTraversal
29from .api import NpuConv2DOperation
30from .api import NpuConvDepthWiseOperation
31from .api import NpuDataType
32from .api import NpuDmaOperation
33from .api import NpuElementWiseOp
34from .api import NpuElementWiseOperation
35from .api import NpuFeatureMap
Louis Verhaarde8a5a782020-11-02 18:04:27 +010036from .api import NpuLayout
37from .api import NpuOperation
38from .api import NpuPadding
39from .api import NpuPoolingOp
40from .api import NpuPoolingOperation
41from .api import NpuQuantization
42from .api import NpuResamplingMode
43from .api import NpuRoundingMode
44from .api import NpuShape3D
45from .api import NpuTileBox
46from .architecture_features import ArchitectureFeatures
47from .data_type import DataType
Louis Verhaard1e170182020-11-26 11:42:04 +010048from .debug_database import DebugDatabase
Michael McGeagh7a6f8432020-12-02 15:29:22 +000049from .errors import UnsupportedFeatureError
Louis Verhaarde8a5a782020-11-02 18:04:27 +010050from .high_level_command_stream import Box
51from .high_level_command_stream import Command
Louis Verhaarde8a5a782020-11-02 18:04:27 +010052from .high_level_command_stream import DMA
53from .high_level_command_stream import NpuStripe
Louis Verhaarde8a5a782020-11-02 18:04:27 +010054from .operation import NpuBlockType
55from .operation import Op
56from .operation import Operation
Louis Verhaard1e170182020-11-26 11:42:04 +010057from .register_command_stream_generator import generate_command_stream
58from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
Louis Verhaard1e170182020-11-26 11:42:04 +010059from .register_command_stream_util import to_npu_kernel
60from .register_command_stream_util import UNARY_ELEMWISE_OPS
patrik.gustavssoneeb85152020-12-21 17:10:40 +000061from .shape4d import Shape4D
Louis Verhaarde8a5a782020-11-02 18:04:27 +010062from .tensor import MemType
63from .tensor import Tensor
64from .tensor import TensorBlockTraversal
65from .tensor import TensorFormat
66from .tensor import TensorPurpose
67
68
class BasePointerIndex(IntEnum):
    """Indices of the base pointers (regions) used when addressing tensors"""

    # base address index for the Weight tensor
    WeightTensor = 0
    # base address index for the Scratch_tensor in the TensorArena
    ScratchTensor = 1
    # base address for the Scratch_fast_tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010073
74
# Maps a DataType to the corresponding NpuDataType.
# Only the data types listed here can be looked up; any other key raises KeyError.
dtype_map = {
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    DataType.int32: NpuDataType.INT32,
}
82
83
# Maps a TensorBlockTraversal to the corresponding NpuBlockTraversal
block_traversal_map = {
    TensorBlockTraversal.DepthFirst: NpuBlockTraversal.DEPTH_FIRST,
    TensorBlockTraversal.PartKernelFirst: NpuBlockTraversal.PART_KERNEL_FIRST,
}
88
89
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# Note that both Op.Add and Op.RescaleAdd translate to NpuElementWiseOp.ADD.
elementwise_op_map = {
    Op.Mul: NpuElementWiseOp.MUL,
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
104
105
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the two inputs are already in the order required by the NPU.

    Returns False when the operands must be swapped: either ifm is a scalar
    (empty shape), or ifm has a dimension of size 1 where ifm2 differs, i.e.
    ifm is the broadcasted feature map. Returns True otherwise.
    """
    if ifm_shape == []:
        # Scalar needs to be in IFM2
        return False
    if ifm2_shape == []:
        return True

    # Broadcasted FM needs to be in IFM2
    ifm_is_broadcasted = any(dim == 1 and dim != dim2 for dim, dim2 in zip(ifm_shape, ifm2_shape))
    return not ifm_is_broadcasted
118
119
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Selects the rounding mode for the given operation.

    A rounding mode explicitly set on the operation (op.rounding_mode) always
    takes precedence over the mode derived from the operation's properties.
    """
    op_type = op.type
    if op_type == Op.ResizeBilinear:
        derived_mode = NpuRoundingMode.TRUNCATE
    elif (
        op_type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
        and op.ifm.dtype == DataType.int16
    ):
        # 16-bit convolutions
        derived_mode = NpuRoundingMode.NATURAL
    elif (
        not fused_quantize
        and op_type.is_avgpool_op()
        and op.memory_function == Op.ConcatSliceWrite
        and op.kernel.elements_wh() == 1
    ):
        # 1x1 average pool writing into a concat slice, without fused quantize
        derived_mode = NpuRoundingMode.NATURAL
    else:
        derived_mode = NpuRoundingMode.TFL

    explicit_mode = op.rounding_mode
    return derived_mode if explicit_mode is None else explicit_mode
140
141
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Creates the padding for the stripe described by cmd.

    Vector products never get padding. Otherwise the operation's
    "explicit_padding" attribute is the starting point, adjusted for the part
    of the IFM that the stripe covers.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)

    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    # When the stripe is not both the first and the last one in the height
    # dimension (horizontal ifm streaming), use the stripe's own top/bottom padding
    covers_full_height = cmd.is_first_h_stripe and cmd.is_last_h_stripe
    if not covers_full_height:
        top, bottom = cmd.pad_top, cmd.pad_bottom

    if not primary_op.attrs.get("force_padding"):
        # Indexing from end since a 1x1 Avgpool might have been added with non 4-dimensional input/output,
        # because of activation function needed to be fused.
        box_start = cmd.ifm_box.start_coord
        box_end = cmd.ifm_box.end_coord
        if len(box_start) >= 2 and box_start[-2] > 0:
            # Box does not start at the left edge of the IFM
            left = 0
        if len(box_end) >= 2 and box_end[-2] < cmd.ps.ifm_shapes[0].width:
            # Box does not reach the right edge of the IFM
            right = 0
    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100160
161
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the region (base pointer index value) used for the given memory type.

    Scratch_fast only gets its own region when spilling is enabled; otherwise
    it shares the region of the scratch tensor.
    Raises KeyError for a memory type that has no region assigned.
    """
    if arch.is_spilling_enabled():
        scratch_fast_index = BasePointerIndex.ScratchFastTensor
    else:
        scratch_fast_index = BasePointerIndex.ScratchTensor

    index_for_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: scratch_fast_index,
    }
    return index_for_mem_type[mem_type].value
175
176
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes.

    Several memory types can share a region; in that case the size of the
    memory type that comes last in MemType.all() is the one that is kept.
    """
    limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    # The mem2mem region is limited by the size of the SHRAM
    limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100184
185
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling (upscale) mode required by the operation"""
    if op.type == Op.ResizeBilinear:
        # perform nearest neighbor upscale
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        # perform insert zero upscale
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
195
196
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to use for the given block type.

    Convolution, vector product and reduce sum take the depth from the IFM
    box; all other block types take it from the OFM box.
    """
    depth_from_ifm = npu_block_type in (
        NpuBlockType.ConvolutionMxN,
        NpuBlockType.VectorProduct,
        NpuBlockType.ReduceSum,
    )
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100203
204
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point.

    This is always the case for int32 input tensors. For other tensors it only
    applies to AvgPool, ResizeBilinear, CLZ and SHL passes that neither fuse a
    Quantize op nor write into a concat slice, and that either have no
    activation or have a forced output quantization.
    """
    if is_ifm_tensor and tens.dtype == DataType.int32:
        return True

    primary_op = ps.primary_op
    if primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if any(ps_op.type == Op.Quantize for ps_op in ps.ops):
        # A Quantize op is fused into the pass
        return False
    if primary_op.memory_function == Op.ConcatSliceWrite:
        return False
    return primary_op.activation is None or primary_op.forced_output_quantization is not None
219
220
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2.

    The primary op's forced input quantization, if set, is used instead of
    the tensor's own quantization. Returns None if neither is available.
    """
    forced_quant = ps.primary_op.forced_input_quantization
    quant = tens.quantization if forced_quant is None else forced_quant
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100232
233
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM.

    The primary op's forced output quantization (used in LUTs), if set, is
    used instead of the output tensor's own quantization. Returns None if
    neither is available.
    """
    forced_quant = ps.primary_op.forced_output_quantization
    quant = tens.quantization if forced_quant is None else forced_quant
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
247
248
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated.

    Region, data type, layout, tiles and strides are filled in; shape and
    quantization are left for the caller to set.
    """
    feature_map = NpuFeatureMap()
    feature_map.region = get_region(tens.mem_type, arch)
    feature_map.data_type = dtype_map[tens.dtype]

    tensor_format = tens.format
    if tensor_format == TensorFormat.NHWC:
        feature_map.layout = NpuLayout.NHWC
    elif tensor_format == TensorFormat.NHCWB16:
        feature_map.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    height_0, height_1, width_0, tile_addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Tiles without an address are given address 0 (replaced in the returned list itself)
    for tile_idx, tile_addr in enumerate(tile_addresses):
        if tile_addr is None:
            tile_addresses[tile_idx] = 0
    int_addresses = [int(tile_addr) for tile_addr in tile_addresses]
    feature_map.tiles = NpuTileBox(height_0=height_0, height_1=height_1, width_0=width_0, addresses=int_addresses)

    # Stride entries 2, 3 and 1 are used as height, width and depth stride respectively
    tensor_strides = tens.get_strides(shape4D=op_shape4D)
    feature_map.strides = NpuShape3D(
        height=int(tensor_strides[2]), width=int(tensor_strides[3]), depth=int(tensor_strides[1])
    )
    return feature_map
272
273
def create_weights(weight_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures) -> List[NpuAddressRange]:
    """Returns address ranges for weights, one for each substream of the compressed weight stream"""
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = weight_tensor.compressed_values_substream_offsets[stream_index]

    # The offset list starts at 0 and must terminate with the full stream length,
    # so that consecutive offsets delimit one substream each
    assert len(offsets) > 1 and (offsets[0] == 0)
    base_addr = weight_tensor.address_for_coordinate(weight_box.start_coord)
    region = get_region(weight_tensor.mem_type, arch)
    return [
        NpuAddressRange(region, int(base_addr + start), int(end - start))
        for start, end in zip(offsets, offsets[1:])
    ]
291
292
def create_biases(
    weight_tensor: Tensor, scale_tensor: Tensor, weight_box: Box, arch: ArchitectureFeatures
) -> List[NpuAddressRange]:
    """Returns address ranges for biases, one for each substream of the compressed scale stream.

    The stream index is looked up through the weight tensor, while offsets,
    address and region come from the scale tensor.
    """
    stream_index = weight_tensor.compressed_stream_index_from_coord(weight_box.start_coord)
    offsets = scale_tensor.compressed_values_substream_offsets[stream_index]

    # The offset list starts at 0 and must terminate with the full stream length,
    # so that consecutive offsets delimit one substream each
    assert len(offsets) > 1 and (offsets[0] == 0)
    # Only the last coordinate of the weight box is used to address the scale tensor
    base_addr = scale_tensor.address_for_coordinate(weight_box.start_coord[-1:])
    region = get_region(scale_tensor.mem_type, arch)
    return [
        NpuAddressRange(region, int(base_addr + start), int(end - start))
        for start, end in zip(offsets, offsets[1:])
    ]
313
314
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function.

    Raises UnsupportedFeatureError if the operation's activation is neither
    Tanh, Sigmoid, LUT nor a relu op.
    """
    activation = op.activation
    if activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = activation.op_type
    dedicated_activations = (
        (Op.Tanh, NpuActivationOp.TANH),
        (Op.Sigmoid, NpuActivationOp.SIGMOID),
        (Op.LUT, NpuActivationOp.TABLE_LOOKUP),
    )
    for op_type, npu_act_op in dedicated_activations:
        if faf == op_type:
            act_op = npu_act_op
            break
    else:
        # Everything else has to be some kind of relu
        if not faf.is_relu_op():
            raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")
        act_op = NpuActivationOp.NONE_OR_RELU

    npu_activation = NpuActivation(act_op)
    npu_activation.min = activation.min
    npu_activation.max = activation.max
    npu_activation.lookup_table_index = activation.lut_index
    return npu_activation
335
336
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets the fields that all block operations have in common; returns npu_op"""
    ps = cmd.ps
    primary_op = ps.primary_op

    # IFM: height comes from the stripe's box, width from the full IFM shape
    ifm_shape = NpuShape3D(
        height=cmd.ifm_box.get_block().height,
        width=ps.ifm_shapes[0].width,
        depth=get_ifm_depth(primary_op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box),
    )
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = ifm_shape
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: the shape is that of the stripe's OFM box
    ofm_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    # Weights and biases are optional
    if cmd.weight_tensor is not None:
        npu_op.weights = create_weights(cmd.weight_tensor, cmd.weight_box, arch)
    if cmd.scale_tensor is not None:
        npu_op.biases = create_biases(cmd.weight_tensor, cmd.scale_tensor, cmd.weight_box, arch)

    npu_op.activation = create_npu_activation(primary_op)
    npu_op.fused_quantize = any(ps_op.type == Op.Quantize for ps_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary_op, npu_op.fused_quantize)

    # Entries 0, 1 and 3 of the pass's block config are height, width and depth
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    # Padding and kernel are not set for elementwise operations
    if not primary_op.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary_op)
        npu_op.kernel = to_npu_kernel(primary_op.kernel)
    npu_op.ifm_upscale = get_upscale(primary_op)
    return npu_op
369
370
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation.

    Vector products always use depth first block traversal; otherwise the
    traversal is taken from the weight tensor.
    """
    conv_op = NpuConv2DOperation()
    set_common_op_fields(conv_op, cmd, arch)
    is_vector_product = cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct
    conv_op.block_traversal = (
        NpuBlockTraversal.DEPTH_FIRST if is_vector_product else block_traversal_map[cmd.weight_tensor.block_traversal]
    )
    return conv_op
380
381
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation; only the common fields need to be set"""
    # set_common_op_fields returns the operation that was passed in
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
387
388
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation.

    Max pool ops map to MAX, average pool ops and ResizeBilinear to AVERAGE,
    and ReduceSum to REDUCE_SUM.
    """
    primary_op = cmd.ps.primary_op
    op_type = primary_op.type
    # Default; also what remains in effect if the assert below is compiled out
    pooling_kind = NpuPoolingOp.AVERAGE
    if op_type.is_maxpool_op():
        pooling_kind = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pooling_kind = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pooling_kind = NpuPoolingOp.REDUCE_SUM
    else:
        assert 0, f"Unknown pool type {primary_op.type}"

    pooling_op = NpuPoolingOperation(pooling_kind)
    set_common_op_fields(pooling_op, cmd, arch)
    # Pooling specific info
    pooling_op.rescale = primary_op.rescale
    return pooling_op
407
408
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation.

    For binary operations the inputs of cmd (tensors, boxes and the pass's
    ifm_shapes) are swapped in place when the scalar/broadcasted input is not
    already the second input; this is recorded in npu_op.reversed_operands.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        # Binary operation: set up IFM2
        first_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        second_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(first_shape, second_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
            swapped_tensor = cmd.ifm_tensor
            cmd.ifm_tensor = cmd.ifm2_tensor
            cmd.ifm2_tensor = swapped_tensor
            swapped_box = cmd.ifm_box
            cmd.ifm_box = cmd.ifm2_box
            cmd.ifm2_box = swapped_box
            swapped_shape = ps.ifm_shapes[0]
            ps.ifm_shapes[0] = ps.ifm_shapes[1]
            ps.ifm_shapes[1] = swapped_shape
            npu_op.reversed_operands = True

        ifm2_tensor = cmd.ifm2_tensor
        npu_op.ifm2 = create_feature_map(ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, ifm2_tensor)
        if ifm2_tensor.shape == []:
            # scalar
            assert ifm2_tensor.quant_values.size == 1
            npu_op.ifm2_scalar = ifm2_tensor.values.item(0)
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            # Height and depth from the stripe's box, width from the full IFM2 shape
            ifm2_block = cmd.ifm2_box.get_block()
            npu_op.ifm2.shape = NpuShape3D(
                height=ifm2_block.height, width=ps.ifm_shapes[1].width, depth=ifm2_block.depth
            )
    set_common_op_fields(npu_op, cmd, arch)

    # Check if output scale needs to be overridden
    output_scale = None
    op_type = op.type
    if op_type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for
        # resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    elif op_type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    elif op_type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]
    elif op_type == Op.RescaleAdd:
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale

    # A fused sigmoid/tanh on Add/Mul/Sub overrides any scale chosen above
    if (
        op_type in (Op.Add, Op.Mul, Op.Sub)
        and op.activation is not None
        and op.activation.op_type in (Op.Sigmoid, Op.Tanh)
    ):
        output_scale = 1 / 0x3000

    if output_scale is not None:
        # Only the scale is replaced; the OFM zero point is kept
        ofm_zero_point = npu_op.ofm.quantization.zero_point
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=ofm_zero_point)
    return npu_op
457
458
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation.

    Source and destination get the same size; only region and address differ.
    """
    in_tensor = cmd.in_tensor
    out_tensor = cmd.out_tensor

    src_region = get_region(in_tensor.mem_type, arch)
    # LUT tensors are transferred to the mem2mem region
    dest_region = (
        BASE_PTR_INDEX_MEM2MEM if out_tensor.purpose == TensorPurpose.LUT else get_region(out_tensor.mem_type, arch)
    )

    start_coord = cmd.box.start_coord
    src_addr = in_tensor.address_for_coordinate(start_coord)
    dest_addr = out_tensor.address_for_coordinate(start_coord)

    if in_tensor.compressed_values is None:
        # Uncompressed: size is the distance from the start to the end of the box
        size = in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
    elif out_tensor.purpose == TensorPurpose.FSBias:
        size = in_tensor.storage_size()
    else:
        # Compressed: size of the compressed stream that the box starts in
        stream_index = in_tensor.compressed_stream_index_from_coord(start_coord)
        size = in_tensor.size_of_compressed_stream(stream_index)

    return NpuDmaOperation(
        NpuAddressRange(src_region, int(src_addr), int(size)),
        NpuAddressRange(dest_region, int(dest_addr), int(size)),
    )
482
483
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation.

    DMA commands are converted to NpuDmaOperation. NpuStripe commands are
    dispatched on the npu_block_type of the pass's primary op.

    Raises:
        TypeError: if cmd is neither a DMA nor an NpuStripe.
        AssertionError: if there is no conversion for the npu_block_type.
    """
    if isinstance(cmd, DMA):
        return create_dma_op(cmd, arch)
    if not isinstance(cmd, NpuStripe):
        # Previously this fell through to the return statement and failed with
        # an UnboundLocalError that did not name the offending command
        raise TypeError(f"Cannot convert command of type {type(cmd).__name__} to NpuOperation")

    npu_block_type = cmd.ps.primary_op.type.npu_block_type
    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
        return create_npu_conv2d_op(cmd, arch)
    if npu_block_type == NpuBlockType.ConvolutionDepthWise:
        return create_npu_conv_depthwise_op(cmd, arch)
    if npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
        return create_npu_pool_op(cmd, arch)
    if npu_block_type == NpuBlockType.ElementWise:
        return create_npu_elementwise_op(cmd, arch)
    # Raised explicitly (not via an assert statement) so that the check also
    # works when Python runs with assertions disabled (-O)
    raise AssertionError(f"Unknown command type {npu_block_type}")
Louis Verhaard1e170182020-11-26 11:42:04 +0100502
503
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream.

    Nothing is generated (and sg is left untouched) when the subgraph's high
    level command stream is empty. The nng argument is not used.
    """
    # Convert high level command stream to list of NpuOperation
    npu_ops = []
    cmd_for_npu_op = {}  # map from npu op to high level command
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
            continue
        npu_op = convert_command_to_npu_op(cmd, arch)
        npu_ops.append(npu_op)
        cmd_for_npu_op[npu_op] = cmd
    mem_limits = get_mem_limits_for_regions(arch)

    # Generate register commands
    if len(sg.high_level_command_stream) == 0:
        return

    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        # DMA operations are not recorded
        if isinstance(npu_op, NpuDmaOperation):
            return
        source_cmd = cmd_for_npu_op[npu_op]
        DebugDatabase.add_command(stream_id, offset, source_cmd.ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_ops, arch, verbose, mem_limits, add_to_debug_db, cmd_for_npu_op
    )