blob: f67114ff2c962de843eba8f68c91db9c300fe290 [file] [log] [blame]
erik.andersson@arm.comad45f792021-02-03 10:20:16 +01001# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
Louis Verhaarde8a5a782020-11-02 18:04:27 +01002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Conversion from high level command to NpuOperation
from enum import IntEnum
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .numeric_util import quantise_float32
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .shape4d import Shape4D
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
from .weight_compressor import WeightKey
Louis Verhaarde8a5a782020-11-02 18:04:27 +010070
71
class BasePointerIndex(IntEnum):
    """Base pointer (region) indices used to address tensors in the generated command stream."""

    # Region holding the weight tensor
    WeightTensor = 0
    # Region holding the scratch tensor in the tensor arena
    ScratchTensor = 1
    # Region holding the fast scratch tensor
    ScratchFastTensor = 2
Louis Verhaarde8a5a782020-11-02 18:04:27 +010076
77
# Compiler DataType -> NpuDataType of the external API.
# Only these five element types can be expressed in an NpuFeatureMap.
dtype_map = {
    # 8-bit
    DataType.uint8: NpuDataType.UINT8,
    DataType.int8: NpuDataType.INT8,
    # 16-bit
    DataType.uint16: NpuDataType.UINT16,
    DataType.int16: NpuDataType.INT16,
    # 32-bit
    DataType.int32: NpuDataType.INT32,
}
85
86
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE.
# The Rescale* variants share the mode of their plain counterpart; their rescale values are
# passed separately on the operation.
elementwise_op_map = {
    # Multiplication
    Op.Mul: NpuElementWiseOp.MUL,
    Op.RescaleMul: NpuElementWiseOp.MUL,
    # Addition / subtraction
    Op.Add: NpuElementWiseOp.ADD,
    Op.RescaleAdd: NpuElementWiseOp.ADD,
    Op.Sub: NpuElementWiseOp.SUB,
    # Min / max
    Op.Minimum: NpuElementWiseOp.MIN,
    Op.Maximum: NpuElementWiseOp.MAX,
    # Activation-like and unary ops
    Op.LeakyRelu: NpuElementWiseOp.LRELU,
    Op.Abs: NpuElementWiseOp.ABS,
    Op.CLZ: NpuElementWiseOp.CLZ,
    # Shifts
    Op.SHR: NpuElementWiseOp.SHR,
    Op.SHL: NpuElementWiseOp.SHL,
}
102
103
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
    """Tells whether the two elementwise operands are already in the order the NPU needs.

    A scalar (empty shape) or a broadcast feature map must be IFM2. The order is wrong when
    IFM is a scalar, or when any IFM dimension is 1 while the matching IFM2 dimension differs.

    :param ifm_shape: shape of the first operand; ``[]`` denotes a scalar
    :param ifm2_shape: shape of the second operand; ``[]`` denotes a scalar
    :return: False if the operands should be swapped, True otherwise
    """
    if ifm_shape == []:
        # A scalar can only be supplied as IFM2
        return False
    if ifm2_shape == []:
        return True
    # Any broadcast dimension on the IFM side means IFM is the broadcast operand
    return not any(dim != dim2 and dim == 1 for dim, dim2 in zip(ifm_shape, ifm2_shape))
116
117
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
    """Specifies type of rounding to be used

    NATURAL rounding is selected for ResizeBilinear, for int16 (depthwise) convolutions, and
    for 1x1 average pools that write into a concat slice when no Quantize is fused.
    Every other op uses TFL rounding. An explicit ``op.rounding_mode`` overrides the choice.

    :param op: the primary operation of the pass
    :param fused_quantize: True if a Quantize op is fused into the pass
    """
    use_natural = (
        op.type == Op.ResizeBilinear
        or (
            op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
            and op.ifm.dtype == DataType.int16
        )
        or (
            not fused_quantize
            and op.type.is_avgpool_op()
            and op.memory_function == Op.ConcatSliceWrite
            and op.kernel.elements_wh() == 1
        )
    )
    default_mode = NpuRoundingMode.NATURAL if use_natural else NpuRoundingMode.TFL
    return default_mode if op.rounding_mode is None else op.rounding_mode
138
139
def create_padding(cmd: NpuStripe, primary_op: Operation) -> NpuPadding:
    """Builds the padding for one stripe of *primary_op*.

    VectorProduct ops get no padding. Otherwise the op's ``explicit_padding`` attribute is
    the starting point.
    - When the op is split into more than one height stripe, top/bottom come from the stripe.
    - Unless the op has the ``force_padding`` attribute set, left padding is dropped for stripes
      that do not start at IFM column 0, and right padding for stripes that end before the IFM
      width.
    """
    if primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        return NpuPadding(top=0, left=0, bottom=0, right=0)

    top, left, bottom, right = primary_op.attrs["explicit_padding"]

    has_several_h_stripes = not (cmd.is_first_h_stripe and cmd.is_last_h_stripe)
    if has_several_h_stripes:
        top, bottom = cmd.pad_top, cmd.pad_bottom

    if not primary_op.attrs.get("force_padding"):
        # Coordinates are indexed from the end: a 1x1 AvgPool added to fuse an activation
        # may have non 4-dimensional input/output.
        start, end = cmd.ifm_box.start_coord, cmd.ifm_box.end_coord
        if len(start) >= 2 and start[-2] > 0:
            left = 0
        if len(end) >= 2 and end[-2] < cmd.ps.ifm_shapes[0].width:
            right = 0

    return NpuPadding(top=top, left=left, bottom=bottom, right=right)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100158
159
def get_region(mem_type: MemType, arch: ArchitectureFeatures) -> int:
    """Returns the base pointer (region) index used for tensors of *mem_type*.

    Both permanent memory types share the weight region. Scratch_fast has its own region
    only when spilling is enabled; otherwise it shares the scratch region.
    Raises KeyError for memory types not listed here.
    """
    if arch.is_spilling_enabled():
        fast_scratch_index = BasePointerIndex.ScratchFastTensor
    else:
        fast_scratch_index = BasePointerIndex.ScratchTensor

    region_by_mem_type = {
        MemType.Permanent_NPU: BasePointerIndex.WeightTensor,
        MemType.Permanent_CPU: BasePointerIndex.WeightTensor,
        MemType.Scratch: BasePointerIndex.ScratchTensor,
        MemType.Scratch_fast: fast_scratch_index,
    }
    return region_by_mem_type[mem_type].value
173
174
def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
    """Returns map region -> max size of the region in bytes

    Memory types that map to the same region overwrite one another, so the size of the last
    such type yielded by ``MemType.all()`` wins. The MEM2MEM region is limited to the SHRAM size.
    """
    mem_limits = {get_region(mem_type, arch): arch.mem_type_size(mem_type) for mem_type in MemType.all()}
    mem_limits[BASE_PTR_INDEX_MEM2MEM] = arch.shram_size_bytes
    return mem_limits
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100182
183
def get_upscale(op: Operation) -> NpuResamplingMode:
    """Returns the IFM resampling mode for *op*.

    ResizeBilinear uses nearest-neighbour upscaling. Conv2DBackpropInputSwitchedBias uses
    insert-zero (transpose) upscaling. All other ops use no upscaling.
    """
    if op.type == Op.ResizeBilinear:
        return NpuResamplingMode.NEAREST
    if op.type == Op.Conv2DBackpropInputSwitchedBias:
        return NpuResamplingMode.TRANSPOSE
    return NpuResamplingMode.NONE
193
194
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
    """Returns the IFM depth to program for a stripe.

    ConvolutionMxN, VectorProduct and ReduceSum blocks use the depth of the IFM box.
    All other block types use the depth of the OFM box.
    """
    depth_from_ifm = npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum)
    source_box = ifm_box if depth_from_ifm else ofm_box
    return source_box.get_block().depth
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100201
202
def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
    """Checks if quantization should use 0 as zero point

    Always true for an int32 input tensor. Otherwise it only applies when the primary op is
    AvgPool (without explicit scaling), ResizeBilinear, CLZ or SHL, and only if:
    - the op does not write into a concat slice,
    - no Quantize op is fused into the pass, and
    - the op has no activation, has a forced output quantization, or is an average pool with
      a ReLU-type activation.
    """
    if tens.dtype == DataType.int32 and is_ifm_tensor:
        return True

    primary = ps.primary_op
    if primary.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
        return False
    if primary.type == Op.AvgPool and primary.explicit_scaling:
        return False

    fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    if primary.memory_function == Op.ConcatSliceWrite or fused_quantize:
        return False

    return (
        primary.activation is None
        or primary.forced_output_quantization is not None
        or (primary.type.is_avgpool_op() and primary.activation.op_type.is_relu_op())
    )
223
224
def get_ifm_or_ifm2_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for IFM/IFM2

    The primary op's forced input quantization, when set, takes precedence over the tensor's
    own quantization. The zero point is forced to 0 when use_zero_point_0 says so.
    Returns None if no quantization is available.
    """
    forced_quant = ps.primary_op.forced_input_quantization
    quant = tens.quantization if forced_quant is None else forced_quant
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, True) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100236
237
def get_ofm_quantization(ps, tens: Tensor) -> Optional[NpuQuantization]:
    """Gets quantization for OFM

    The primary op's forced output quantization (used in LUTs), when set, takes precedence
    over the tensor's own quantization. The zero point is forced to 0 when use_zero_point_0
    says so. Returns None if no quantization is available.
    """
    forced_quant = ps.primary_op.forced_output_quantization
    quant = tens.quantization if forced_quant is None else forced_quant
    if quant is None:
        return None
    zero_point = 0 if use_zero_point_0(ps, tens, False) else int(quant.zero_point)
    return NpuQuantization(scale_f32=quant.scale_f32, zero_point=zero_point)
251
252
def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_shape4D: Shape4D) -> NpuFeatureMap:
    """Creates feature map with common fields populated

    Sets region, data type, layout, the rolling-buffer tiles covering *box*, and the strides of
    *tens* viewed through *op_shape4D*. Shape and quantization are left to the caller.
    """
    feature_map = NpuFeatureMap()
    feature_map.region = get_region(tens.mem_type, arch)
    feature_map.data_type = dtype_map[tens.dtype]

    if tens.format == TensorFormat.NHWC:
        feature_map.layout = NpuLayout.NHWC
    elif tens.format == TensorFormat.NHCWB16:
        feature_map.layout = NpuLayout.NHCWB16
    else:
        assert 0, "Incorrect tensor format"

    height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
        box.start_coord, box.end_coord, op_shape4D
    )
    # Tile addresses that come back as None are replaced by 0 (in place, as before)
    for i, address in enumerate(addresses):
        if address is None:
            addresses[i] = 0
    tile_addresses = [int(address) for address in addresses]
    feature_map.tiles = NpuTileBox(height_0=height_0, height_1=height_1, width_0=width_0, addresses=tile_addresses)

    strides = tens.get_strides(shape4D=op_shape4D)
    feature_map.strides = NpuShape3D(height=int(strides[2]), width=int(strides[3]), depth=int(strides[1]))
    return feature_map
276
277
def create_weights(
    weight_tensor: Tensor, weight_box: Box, scale_tensor: Optional[Tensor], arch: ArchitectureFeatures
) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
    """Returns address ranges for weights and scales

    Each NPU core that has an encoded range for the depth slice starting at
    ``weight_box.start_coord[-1]`` gets one weight range and one scale (bias) range.

    :param weight_tensor: tensor the weights are read from. This is either the encoded source
        tensor itself, or a single/double buffered copy whose ``src_tensor`` is the source.
    :param weight_box: box whose last start coordinate selects the encoded range
    :param scale_tensor: optional standalone scale tensor. When falsy, scales are read from the
        combined encoding in the weight tensor.
    :param arch: architecture features (core count, region mapping)
    :return: tuple ``(weights, biases)``, each a list with one NpuAddressRange per core used
    """
    weights = []
    biases = []
    shared_region = get_region(weight_tensor.mem_type, arch)
    scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else None

    # Encoded ranges always live on the source tensor; a buffered copy refers back to it
    w_tensor_src = weight_tensor.src_tensor if weight_tensor.src_tensor else weight_tensor

    core_offset = 0
    for core in range(0, arch.ncores):
        # Get weight range per core; cores without an encoded range get no address ranges
        key = WeightKey(core, weight_box.start_coord[-1])
        if key not in w_tensor_src.encoded_ranges:
            continue
        weight_range = w_tensor_src.encoded_ranges[key]

        if weight_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
            assert weight_tensor != w_tensor_src
            # Double buffered inside weight_tensor: the buffer half is selected by the
            # parity of (range index - core)
            address = weight_tensor.address + w_tensor_src.max_range_bytes * ((weight_range.index - core) % 2)
            address += core_offset
            core_offset += round_up(weight_range.total_bytes, 16)
        elif weight_tensor == w_tensor_src:
            # Straight from source tensor
            address = weight_tensor.address + weight_range.offset
        else:
            # Single buffered inside weight tensor: per-core ranges follow each other,
            # each rounded up to 16 bytes
            address = weight_tensor.address + core_offset
            core_offset += round_up(weight_range.total_bytes, 16)

        # Location of weights in tensor
        weights.append(
            NpuAddressRange(
                shared_region, int(address + weight_range.weight_offset), round_up(int(weight_range.weight_bytes), 16)
            )
        )

        # Location of standalone scales or combined weights tensor scales
        if scale_tensor:
            assert scale_tensor.src_tensor is None  # Must be standalone
            scale_range = scale_tensor.encoded_ranges[key]
            scale_address = scale_tensor.address + scale_range.offset
            scale_addr_range = NpuAddressRange(
                scale_region, int(scale_address), round_up(int(scale_range.scale_bytes), 16)
            )
        else:
            # Combined encoding: scales start at the range's base address, weights at +weight_offset
            scale_addr_range = NpuAddressRange(
                shared_region, int(address), round_up(int(weight_range.scale_bytes), 16)
            )
        biases.append(scale_addr_range)

    return weights, biases
Louis Verhaarde8a5a782020-11-02 18:04:27 +0100330
331
def create_npu_activation(op: Operation) -> NpuActivation:
    """Creates fused activation function

    Tanh, Sigmoid and LUT map to their own NpuActivationOp. ReLU-type activations, and no
    activation at all, map to NONE_OR_RELU. The clamp range and LUT index are copied from
    ``op.activation``.

    For average pools with NONE_OR_RELU and a non-zero OFM zero point, each defined min/max
    is replaced by ``scale * quantise_float32(value, scale, zero_point)``. The scale defaults
    to 1 when the OFM scale is None.

    :raises UnsupportedFeatureError: for any other fused activation
    """
    if op.activation is None:
        return NpuActivation(NpuActivationOp.NONE_OR_RELU)

    faf = op.activation.op_type
    special_activations = {
        Op.Tanh: NpuActivationOp.TANH,
        Op.Sigmoid: NpuActivationOp.SIGMOID,
        Op.LUT: NpuActivationOp.TABLE_LOOKUP,
    }
    act_op = special_activations.get(faf)
    if act_op is None:
        if not faf.is_relu_op():
            raise UnsupportedFeatureError(f"Unsupported fused_activation_function: {faf.name}")
        act_op = NpuActivationOp.NONE_OR_RELU

    act = NpuActivation(act_op)
    act.min = op.activation.min
    act.max = op.activation.max

    if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op():
        quant = op.ofm.quantization
        if quant and quant.zero_point:  # Zero point is not 0
            scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
            zero_point = quant.zero_point

            def requantise(value):
                # Undefined bounds stay undefined
                if value is None:
                    return None
                return scale_f32 * quantise_float32(value, scale_f32, zero_point)

            act.min = requantise(act.min)
            act.max = requantise(act.max)

    act.lookup_table_index = op.activation.lut_index
    return act
361
362
def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: ArchitectureFeatures):
    """Sets common fields of the given operation

    Fills in the following fields:
    - IFM/OFM feature maps with shapes and quantization
    - weights/biases, when the stripe has a weight tensor
    - activation, fused_quantize, rounding mode and block config
    - for non-elementwise ops: padding and kernel
    - IFM upscale mode

    Returns *npu_op*.
    """
    ps = cmd.ps
    primary = ps.primary_op

    # IFM: height of this stripe's box, full IFM width, depth per get_ifm_depth
    ifm_stripe_height = cmd.ifm_box.get_block().height
    ifm_full_width = cmd.ps.ifm_shapes[0].width
    ifm_stripe_depth = get_ifm_depth(primary.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
    npu_op.ifm.shape = NpuShape3D(height=ifm_stripe_height, width=ifm_full_width, depth=ifm_stripe_depth)
    npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)

    # OFM: exactly the stripe's output box
    ofm_block = cmd.ofm_box.get_block()
    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0])
    npu_op.ofm.shape = NpuShape3D(height=ofm_block.height, width=ofm_block.width, depth=ofm_block.depth)
    npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)

    if cmd.weight_tensor is not None:
        npu_op.weights, npu_op.biases = create_weights(cmd.weight_tensor, cmd.weight_box, cmd.scale_tensor, arch)
    npu_op.activation = create_npu_activation(primary)
    npu_op.fused_quantize = any(pass_op.type == Op.Quantize for pass_op in ps.ops)
    npu_op.rounding_mode = get_rounding_mode(primary, npu_op.fused_quantize)
    block_config = ps.block_config
    npu_op.block_config = NpuShape3D(height=block_config[0], width=block_config[1], depth=block_config[3])

    if not primary.type.is_elementwise_op():
        npu_op.padding = create_padding(cmd, primary)
        npu_op.kernel = to_npu_kernel(primary.kernel)
    npu_op.ifm_upscale = get_upscale(primary)
    return npu_op
393
394
def create_npu_conv2d_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConv2DOperation:
    """Converts the command to NpuConv2DOperation

    VectorProduct ops always use depth-first block traversal. Other convolutions use the
    hw_traversal recorded on the weights: on the source tensor if there is one, otherwise on
    the weight tensor itself.
    """
    npu_op = NpuConv2DOperation()
    set_common_op_fields(npu_op, cmd, arch)
    if cmd.ps.primary_op.type.npu_block_type == NpuBlockType.VectorProduct:
        npu_op.block_traversal = NpuBlockTraversal.DEPTH_FIRST
    else:
        encoded_weights = cmd.weight_tensor.src_tensor or cmd.weight_tensor
        npu_op.block_traversal = encoded_weights.hw_traversal
    return npu_op
407
408
def create_npu_conv_depthwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuConvDepthWiseOperation:
    """Converts the command to NpuConvDepthWiseOperation (only the common fields are needed)."""
    # set_common_op_fields populates the op in place and hands the same object back
    return set_common_op_fields(NpuConvDepthWiseOperation(), cmd, arch)
414
415
def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPoolingOperation:
    """Converts the command to NpuPoolingOperation

    Max pools map to MAX, average pools and ResizeBilinear map to AVERAGE, and ReduceSum
    maps to REDUCE_SUM. The op's rescale is copied. When the op uses explicit scaling, that
    value is stored in ``rescale`` instead, so it is not exposed in the external API.
    """
    op = cmd.ps.primary_op
    op_type = op.type
    if op_type.is_maxpool_op():
        pool_op = NpuPoolingOp.MAX
    elif op_type.is_avgpool_op() or op_type == Op.ResizeBilinear:
        pool_op = NpuPoolingOp.AVERAGE
    elif op_type == Op.ReduceSum:
        pool_op = NpuPoolingOp.REDUCE_SUM
    else:
        # AVERAGE is only kept when assertions are disabled
        pool_op = NpuPoolingOp.AVERAGE
        assert 0, f"Unknown pool type {op.type}"

    npu_op = NpuPoolingOperation(pool_op)
    set_common_op_fields(npu_op, cmd, arch)

    # Pooling specific info
    npu_op.rescale = op.rescale
    if op.explicit_scaling:
        # Explicit scaling reuses the rescale field; both must not be set at once
        assert npu_op.rescale is None
        npu_op.rescale = op.explicit_scaling
    return npu_op
438
439
def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuElementWiseOperation:
    """Converts the command to NpuElementWiseOperation

    For binary ops, the scalar or broadcast operand must be IFM2. If it is not, the operands
    are swapped on *cmd* and in ``cmd.ps.ifm_shapes``, and ``reversed_operands`` is set.
    A scalar IFM2 is passed as ``ifm2_scalar`` with an empty IFM2 shape.

    The OFM scale is overridden in these cases:
    - Add carrying a ``resizebilinear`` attribute: the IFM2 scale
    - Abs: IFM scale / OFM scale
    - LeakyRelu: alpha
    - Add/Mul/Sub with a fused Sigmoid/Tanh: 1 / 0x3000

    RescaleAdd/RescaleMul copy the op's rescale.
    """
    ps = cmd.ps
    op = ps.primary_op
    assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
    elemwise_op = elementwise_op_map[op.type]
    npu_op = NpuElementWiseOperation(elemwise_op)

    if elemwise_op not in UNARY_ELEMWISE_OPS:
        first_shape = [] if cmd.ifm_tensor.shape == [] else ps.ifm_shapes[0].as_list()
        second_shape = [] if cmd.ifm2_tensor.shape == [] else ps.ifm_shapes[1].as_list()
        if not ifm_ifm2_correct_order(first_shape, second_shape):
            # The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms.
            # NOTE(review): tensors/boxes are swapped on this command but ifm_shapes on the shared
            # pass; if one pass yields several stripes, later stripes may see mismatched
            # shapes - confirm.
            cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
            cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
            npu_op.reversed_operands = True

        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
        npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
        if cmd.ifm2_tensor.shape == []:
            # Scalar operand: value goes in ifm2_scalar, feature map shape is empty
            npu_op.ifm2_scalar = cmd.ifm2_tensor.get_scalar()
            npu_op.ifm2.shape = NpuShape3D(height=0, width=0, depth=0)
        else:
            ifm2_block = cmd.ifm2_box.get_block()
            npu_op.ifm2.shape = NpuShape3D(
                height=ifm2_block.height, width=ps.ifm_shapes[1].width, depth=ifm2_block.depth
            )

    set_common_op_fields(npu_op, cmd, arch)

    # Check if output scale needs to be overridden (the op types below are mutually exclusive)
    output_scale = None
    if op.type == Op.Add and "resizebilinear" in op.attrs:
        # Force output scale same as the input scale for resizebilinear 1x1 that is converted to add
        output_scale = npu_op.ifm2.quantization.scale_f32
    elif op.type == Op.Abs:
        output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
    elif op.type == Op.LeakyRelu:
        output_scale = op.attrs["alpha"]

    if op.type in (Op.RescaleAdd, Op.RescaleMul):
        assert op.rescale is not None, f"{op.type} must have rescale"
        npu_op.rescale = op.rescale

    if (
        op.type in (Op.Add, Op.Mul, Op.Sub)
        and op.activation is not None
        and op.activation.op_type in (Op.Sigmoid, Op.Tanh)
    ):
        output_scale = 1 / 0x3000

    if output_scale is not None:
        npu_op.ofm.quantization = NpuQuantization(scale_f32=output_scale, zero_point=npu_op.ofm.quantization.zero_point)
    return npu_op
487
488
def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
    """Converts the command to NpuDmaOperation

    LUT destinations are written to the MEM2MEM region. For weight tensors:
    - the transfer size is the sum of every core's encoded range, each rounded up to 16 bytes;
    - the source address comes from core 0's range;
    - for a double-buffered destination, the buffer half is chosen from core 0's range index.

    Other tensors are copied from the box's start coordinate up to its end coordinate.
    """
    src_tensor = cmd.in_tensor
    dst_tensor = cmd.out_tensor
    src_region = get_region(src_tensor.mem_type, arch)
    if dst_tensor.purpose == TensorPurpose.LUT:
        dest_region = BASE_PTR_INDEX_MEM2MEM
    else:
        dest_region = get_region(dst_tensor.mem_type, arch)

    if src_tensor.purpose == TensorPurpose.Weights:
        # Get weight range per core
        sz = 0
        for core in range(0, arch.ncores):
            key = WeightKey(core, cmd.box.start_coord[-1])
            if key not in src_tensor.encoded_ranges:
                continue
            weight_range = src_tensor.encoded_ranges[key]
            sz += round_up(weight_range.total_bytes, 16)
            if core != 0:
                continue
            src_addr = src_tensor.address + weight_range.offset
            if dst_tensor.sub_purpose == TensorSubPurpose.DoubleBuffer:
                buffer_half = (weight_range.index - core) % 2
                dest_addr = dst_tensor.address + src_tensor.max_range_bytes * buffer_half
            else:
                dest_addr = dst_tensor.address
    else:
        start_coord = cmd.box.start_coord
        src_addr = src_tensor.address_for_coordinate(start_coord)
        dest_addr = dst_tensor.address_for_coordinate(start_coord)
        sz = src_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr

    return NpuDmaOperation(
        NpuAddressRange(src_region, int(src_addr), int(sz)),
        NpuAddressRange(dest_region, int(dest_addr), int(sz)),
    )
524
525
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
    """Converts the high level command to NpuOperation

    DMA commands become NpuDmaOperation. NpuStripe commands are dispatched on the primary
    op's NPU block type:
    - ConvolutionMxN / VectorProduct: conv2d
    - ConvolutionDepthWise: depthwise conv
    - Pooling / ReduceSum: pooling
    - ElementWise: elementwise

    Any other command or block type fails with an assertion instead of an opaque
    UnboundLocalError on ``npu_op``.
    """
    npu_op: NpuOperation
    if isinstance(cmd, DMA):
        npu_op = create_dma_op(cmd, arch)
    elif isinstance(cmd, NpuStripe):
        npu_block_type = cmd.ps.primary_op.type.npu_block_type
        if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
            npu_op = create_npu_conv2d_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            npu_op = create_npu_conv_depthwise_op(cmd, arch)
        elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
            npu_op = create_npu_pool_op(cmd, arch)
        elif npu_block_type == NpuBlockType.ElementWise:
            npu_op = create_npu_elementwise_op(cmd, arch)
        else:
            assert 0, f"Unknown command type {npu_block_type}"
    else:
        # Neither a DMA nor an NpuStripe: report the offending class instead of falling
        # through to an unbound npu_op
        assert 0, f"Unknown command type {type(cmd).__name__}"
    return npu_op
Louis Verhaard1e170182020-11-26 11:42:04 +0100544
545
def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
    """Generates command stream for the subgraph, adds it to sg.register_command_stream

    NpuStripe commands whose pass has the Default NPU block type are skipped with a warning.
    If the high level command stream is empty, nothing is generated and no debug stream is
    registered. ``nng`` is not used by this function.
    """
    # Convert high level command stream to list of NpuOperation
    npu_op_list = []
    npu_op_to_cmd = {}  # npu op -> high level command it was created from
    for cmd in sg.high_level_command_stream:
        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
            continue
        converted = convert_command_to_npu_op(cmd, arch)
        npu_op_list.append(converted)
        npu_op_to_cmd[converted] = cmd

    mem_limits = get_mem_limits_for_regions(arch)
    if len(sg.high_level_command_stream) == 0:
        return

    # Generate register commands
    stream_id = DebugDatabase.add_stream(sg)
    sg.generated_stream_id = stream_id

    def add_to_debug_db(npu_op: NpuOperation, offset: int):
        """Adds info to the debug database"""
        if isinstance(npu_op, NpuDmaOperation):
            return
        DebugDatabase.add_command(stream_id, offset, npu_op_to_cmd[npu_op].ps.primary_op)

    sg.register_command_stream = generate_command_stream(
        npu_op_list, arch, verbose, mem_limits, add_to_debug_db, npu_op_to_cmd
    )