blob: f49df259f6da300022e0f35d85d1fe1d313f996b [file] [log] [blame]
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Contains external APIs
Louis Verhaarde8a5a782020-11-02 18:04:27 +010019from enum import auto
20from enum import Enum
21from typing import List
22from typing import NamedTuple
23from typing import Optional
24from typing import Tuple
25
Louis Verhaardaeae5672020-11-02 18:04:27 +010026import numpy
27
Patrik Gustavssonc74682c2021-08-17 14:26:38 +020028
# Version of this external API: the major and minor numbers, and the
# combined "major.minor" string built from them
API_VERSION_MAJOR = 1
API_VERSION_MINOR = 2
API_VERSION = f"{API_VERSION_MAJOR}.{API_VERSION_MINOR}"
Patrik Gustavssonc8a22f12020-11-18 17:05:50 +010032
Louis Verhaarde8a5a782020-11-02 18:04:27 +010033
class NpuAccelerator(Enum):
    """
    The accelerator configurations that can be selected through this API
    """

    # Ethos-U55 configurations
    Ethos_U55_32 = auto()
    Ethos_U55_64 = auto()
    Ethos_U55_128 = auto()
    Ethos_U55_256 = auto()
    # Ethos-U65 configurations
    Ethos_U65_256 = auto()
    Ethos_U65_512 = auto()
45
46
class NpuElementWiseOp(Enum):
    """
    The kinds of elementwise operation
    """

    ADD = auto()
    SUB = auto()
    MUL = auto()
    ABS = auto()
    MIN = auto()
    MAX = auto()
    # Leaky relu
    LRELU = auto()
    # Number leading zeros
    CLZ = auto()
    # Rounded right-shift
    SHR = auto()
    # Bitwise shift-left
    SHL = auto()
62
63
class NpuPoolingOp(Enum):
    """
    The kinds of pooling operation
    """

    MAX = auto()
    AVERAGE = auto()
    REDUCE_SUM = auto()
72
73
class NpuActivationOp(Enum):
    """
    The kinds of activation function
    """

    # Clamps output using min/max
    NONE_OR_RELU = auto()
    TANH = auto()
    SIGMOID = auto()
    # Performs table look-up, using the provided table lookup index
    TABLE_LOOKUP = auto()
83
84
class NpuRoundingMode(Enum):
    """
    The rounding modes that can be selected for an operation
    """

    # TensorFlow Lite rounding
    TFL = auto()
    # Truncate towards zero
    TRUNCATE = auto()
    # Round to nearest with x.5 rounded up, towards +infinity
    NATURAL = auto()
93
94
class NpuLayout(Enum):
    """
    The tensor layouts a feature map can have
    """

    NHWC = auto()
    NHCWB16 = auto()

    def __str__(self):
        """Returns the member name, e.g. "NHWC" """
        layout_name = self.name
        return layout_name
105
106
class NpuResamplingMode(Enum):
    """
    The resampling modes
    """

    # No resampling is performed
    NONE = auto()
    # 2x2 insert nearest
    NEAREST = auto()
    # 2x2 transpose
    TRANSPOSE = auto()
115
116
class NpuBlockTraversal(Enum):
    """
    The ways in which weights can be block-traversed
    """

    DEPTH_FIRST = auto()
    PART_KERNEL_FIRST = auto()
124
125
class NpuDataType(Enum):
    """
    The data types that feature maps can have
    """

    # Each member value is a tuple: (size in bits, signed flag, auto() tag)
    UINT8 = 8, False, auto()
    INT8 = 8, True, auto()
    UINT16 = 16, False, auto()
    INT16 = 16, True, auto()
    INT32 = 32, True, auto()

    def is_signed(self) -> bool:
        """Tells whether this data type is signed (True) or unsigned (False)"""
        signed = self.value[1]
        return signed

    def size_in_bits(self) -> int:
        """Number of bits occupied by one element of this type"""
        bits = self.value[0]
        return bits

    def size_in_bytes(self) -> int:
        """Number of bytes occupied by one element of this type"""
        bits = self.value[0]
        return bits // 8

    def min_value(self) -> int:
        """Smallest value that can be represented by this type"""
        # Signed types use two's complement range; unsigned types start at 0
        return -(1 << (self.size_in_bits() - 1)) if self.is_signed() else 0

    def max_value(self) -> int:
        """Largest value that can be represented by this type"""
        # Signed types have one bit less available for the magnitude
        magnitude_bits = self.size_in_bits() - 1 if self.is_signed() else self.size_in_bits()
        return (1 << magnitude_bits) - 1

    def __str__(self):
        """Returns the member name, e.g. "INT8" """
        type_name = self.name
        return type_name

    # repr() gives the same short form as str()
    __repr__ = __str__
167
168
class NpuAddressRange(NamedTuple):
    """
    A range of addresses inside a memory region
    """

    # Memory region, a value between 0 and 7
    region: int
    # Address, offset from the region's base address
    address: int
    # The length of the range, in bytes
    length: int

    def __str__(self):
        """Formats the range with the address shown in hexadecimal"""
        region, address, length = self
        return "(region={}, address={}, length={})".format(region, hex(address), length)
180
181
class NpuTileBox(NamedTuple):
    """
    Addresses and dimensions of the tiles that make up a feature map.
    Between 1 and 4 tiles can be used for a feature map
    """

    # The height of tile 0
    height_0: int
    # The height of tile 1, 0 if unused
    height_1: int
    # The width of tile 0, and tile 2 (if used)
    width_0: int
    # A list of 4 addresses, set unused addresses to 0
    addresses: List[int]
192
193
class NpuShape3D(NamedTuple):
    """
    Height/width/depth shape of a feature map, or of a part of one
    """

    height: int
    width: int
    depth: int
202
203
class NpuQuantization(NamedTuple):
    """
    Quantization parameters: a scale and a zero point
    """

    # Scale as a float; may be None
    scale_f32: Optional[float]
    zero_point: int
211
212
class NpuPadding(NamedTuple):
    """
    The padding that a convolution operation should apply, per edge
    """

    top: int
    left: int
    bottom: int
    right: int
222
223
class NpuActivation:
    """
    An activation function that is fused with an NPU operation
    """

    def __init__(self, op_type: NpuActivationOp):
        # Which activation operation to perform
        self.op_type = op_type
        # Optional clamping bounds, e.g. min=0.0 for RELU, max=6.0 for RELU6
        self.min: Optional[float] = None
        self.max: Optional[float] = None
        # Index (0-7) of the lookup table; only applicable for TABLE_LOOKUP activation
        self.lookup_table_index: int = 0
236
237
class NpuFeatureMap:
    """
    Basic information about IFM, IFM2, OFM
    """

    def __init__(self):
        self.data_type: NpuDataType = NpuDataType.UINT8
        # The memory region, a value 0-7
        self.region: int = 0
        # Shape of the feature map
        self.shape: NpuShape3D = NpuShape3D(height=0, width=0, depth=0)
        # The tiles that comprise the feature map. In the normal case when only 1 tile is used,
        # height_0 == self.shape.height, height_1 is 0, width_0 == self.shape.width, addresses[1:] are set to 0
        self.tiles: NpuTileBox = NpuTileBox(height_0=0, height_1=0, width_0=0, addresses=[0, 0, 0, 0])
        # Quantization parameters, None if the feature map has none.
        # Assigned explicitly: an annotation without a value would not create the
        # attribute, so reading it on a fresh instance would raise AttributeError
        self.quantization: Optional[NpuQuantization] = None
        self.layout: NpuLayout = NpuLayout.NHWC
        # x/y/c strides used by the NPU when traversing the feature map, if None, vela will use default strides
        self.strides: Optional[NpuShape3D] = None
256
257
class NpuKernel:
    """
    Kernel size, stride and dilation used by NPU operations
    """

    def __init__(self, w: int, h: int, stride_x: int = 1, stride_y: int = 1, dilation_x: int = 1, dilation_y: int = 1):
        # Strides and dilations must all be strictly positive
        assert stride_x > 0
        assert stride_y > 0
        assert dilation_x > 0
        assert dilation_y > 0
        self.width, self.height = w, h
        self.stride_x, self.stride_y = stride_x, stride_y
        self.dilation_x, self.dilation_y = dilation_x, dilation_y
272
273
class NpuOperationType(Enum):
    """
    The kinds of NPU operation
    """

    Dma = auto()
    Conv2D = auto()
    ConvDepthWise = auto()
    Pooling = auto()
    ElementWise = auto()
284
285
class NpuOperation:
    """
    Common base class of every NPU operation
    """

    def __init__(self, op_type: NpuOperationType):
        # The kind of operation this object represents
        self.op_type = op_type
293
294
class NpuDmaOperation(NpuOperation):
    """
    A DMA operation, from a source to a destination address range
    """

    def __init__(self, src: NpuAddressRange, dest: NpuAddressRange):
        super().__init__(NpuOperationType.Dma)
        # Source and destination address ranges
        self.src = src
        self.dest = dest
        # DMA channel, usually 0 (user channel)
        self.channel: int = 0
        # Channel mode, 0 = external, 1 = internal (should usually be 0)
        self.mode: int = 0
308
309
class NpuBlockOperation(NpuOperation):
    """
    Common base class of the operations that produce an OFM
    """

    def __init__(self, op_type: NpuOperationType):
        super().__init__(op_type)
        # Input feature maps
        self.ifm: Optional[NpuFeatureMap] = None
        self.ifm2: Optional[NpuFeatureMap] = None
        # The non-quantized scalar value in a binary elementwise operation. Only set if IFM2 is scalar
        self.ifm2_scalar: Optional[float] = None
        # Output feature map
        self.ofm: Optional[NpuFeatureMap] = None
        self.kernel: Optional[NpuKernel] = None
        # Weights, one element for each NPU core, empty if no weights are used.
        # Must have been compressed using npu_encode_weights()
        self.weights: List[NpuAddressRange] = []
        # Biases, one element for each NPU core, empty if no bias is used.
        # Must have been encoded using npu_encode_bias()
        self.biases: List[NpuAddressRange] = []
        self.padding: Optional[NpuPadding] = None
        # Optional activation function to be applied
        self.activation: Optional[NpuActivation] = None
        # The block config to be used, which must be valid for the given operation.
        # See also npu_find_block_configs.
        # If the operation has weights, the depth of the block config must be the same as
        # the ofm depth used in the call to npu_encode_weights().
        # Note: this is only an annotation, no value is assigned here, so the
        # attribute does not exist until it has been set on the instance
        self.block_config: NpuShape3D
        self.rounding_mode: NpuRoundingMode = NpuRoundingMode.TFL
        # Set to True if the operations is fused with a Quantize operation (affects scaling)
        self.fused_quantize: bool = False
        # IFM upscaling to be applied
        self.ifm_upscale: NpuResamplingMode = NpuResamplingMode.NONE
342
343
class NpuConv2DOperation(NpuBlockOperation):
    """
    The NPU_OP_CONV operation
    """

    def __init__(self):
        super().__init__(NpuOperationType.Conv2D)
        # Block traversal must be consistent with the block_traversal parameter specified in
        # weight_compressor.encode_weights()
        self.block_traversal: NpuBlockTraversal = NpuBlockTraversal.PART_KERNEL_FIRST
354
355
class NpuConvDepthWiseOperation(NpuBlockOperation):
    """
    The NPU_OP_DEPTHWISE operation
    """

    def __init__(self):
        # No attributes beyond those of the block operation base class
        super().__init__(NpuOperationType.ConvDepthWise)
363
364
class NpuPoolingOperation(NpuBlockOperation):
    """
    The NPU_OP_POOL operation
    """

    def __init__(self, pooling_op_type: NpuPoolingOp):
        super().__init__(NpuOperationType.Pooling)
        # The kind of pooling to perform
        self.sub_op_type: NpuPoolingOp = pooling_op_type
        # Set to a float value for ResizeBilinear operations (affects scaling), else to None
        self.rescale: Optional[float] = None
375
376
class NpuElementWiseOperation(NpuBlockOperation):
    """
    The NPU_OP_ELEMENTWISE operation
    """

    def __init__(self, elementwise_op_type: NpuElementWiseOp):
        super().__init__(NpuOperationType.ElementWise)
        # The kind of elementwise operation to perform
        self.sub_op_type: NpuElementWiseOp = elementwise_op_type
        # Set to True for binary operators where IFM2 should be used as first operand
        self.reversed_operands: bool = False
        # Set to a tuple (scale, shift) for explicit rescale, else to None
        self.rescale: Optional[Tuple] = None
Patrik Gustavssonc8a22f12020-11-18 17:05:50 +0100389
390
def npu_get_api_version():
    """
    Public facing API to get the API version, packed into a single integer
    :return: int, the 16 most significant bits, corresponding to major version
    the 16 least significant bits, corresponding to minor version
    """
    major_part = API_VERSION_MAJOR << 16
    minor_part = API_VERSION_MINOR & 0xFFFF
    return major_part | minor_part
Louis Verhaardaeae5672020-11-02 18:04:27 +0100399
400
def npu_encode_weights(
    accelerator: NpuAccelerator,
    weights_volume: numpy.ndarray,
    dilation_xy: Tuple[int, int],
    ifm_bitdepth: int,
    ofm_block_depth: int,
    is_depthwise: bool,
    block_traversal: NpuBlockTraversal,
):
    """
    Public facing API to use the Ethos-U weight encoding.

    :param accelerator: NpuAccelerator enum to pick the correct accelerator
    :param weights_volume: numpy.ndarray in OHWI layout with a shape of four
    :param dilation_xy: a two element tuple of dilation attributes in x,y dimension
    :param ifm_bitdepth: the bitdepth of input feature map
    :param ofm_block_depth: the depth of blocks for processing
    :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
    :param block_traversal: indicates how these weights are traversed on sub-kernel basis
    :return: a bytearray of encoded weights
    """
    from .architecture_features import Accelerator
    from . import weight_compressor

    # Translate the public accelerator enum to the internal one before encoding
    internal_accelerator = Accelerator.from_npu_accelerator(accelerator)
    encode_result = weight_compressor.encode_weights(
        internal_accelerator,
        weights_volume,
        dilation_xy,
        ifm_bitdepth,
        ofm_block_depth,
        is_depthwise,
        block_traversal,
    )
    # The encoder returns a pair; only its first element is part of this API
    encoded_weights, _ = encode_result
    return encoded_weights
Louis Verhaardaeae5672020-11-02 18:04:27 +0100430
431
def npu_encode_bias(bias: numpy.int64, scale: int, shift: int):
    """
    Public facing API to pack bias and scale values as required by the hardware.
    Delegates to weight_compressor.encode_bias().

    :param bias: 64-bit signed number that includes 40-bit signed bias
    :param scale: 32-bit scale value
    :param shift: 6-bit shift value
    :return: packed 80-bit [0(2-bits),shift(6-bits),scale(32-bits),bias(40-bits)]
    """
    from . import weight_compressor

    packed = weight_compressor.encode_bias(bias, scale, shift)
    return packed
443
444
def npu_find_block_configs(npu_op: NpuOperation, accelerator: NpuAccelerator) -> List[NpuShape3D]:
    """
    Public facing API that returns a list of block configs that are valid for the given operation.
    This function can be used to find a valid value for npu_op.block_config.
    The block config is the unit of work in which the NPU generates the OFM.

    :param npu_op: the operation to find block configs for; must be a Conv2D, depthwise
                   convolution, pooling or elementwise operation with ifm and ofm set
    :param accelerator: NpuAccelerator enum to pick the correct accelerator
    :return: the valid block configs, as a list of NpuShape3D
    :raises AssertionError: if the operation is of an unsupported kind, or
                            (when assertions are enabled) if no valid block config was found
    """
    from .architecture_features import Accelerator
    from .architecture_features import ArchitectureFeatures
    from .architecture_features import Block
    from .architecture_features import create_default_arch
    from .architecture_allocator import try_block_config
    from .register_command_stream_generator import resampling_mode_map
    from .register_command_stream_util import to_kernel
    from .operation import NpuBlockType

    def to_block(shape):
        # Block is constructed as (width, height, depth), NpuShape3D is (height, width, depth)
        return Block(shape.width, shape.height, shape.depth)

    is_partkernel = False
    if isinstance(npu_op, NpuConv2DOperation):
        block_type = NpuBlockType.ConvolutionMxN
        is_partkernel = npu_op.block_traversal == NpuBlockTraversal.PART_KERNEL_FIRST
    elif isinstance(npu_op, NpuConvDepthWiseOperation):
        block_type = NpuBlockType.ConvolutionDepthWise
    elif isinstance(npu_op, NpuPoolingOperation):
        block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
    elif isinstance(npu_op, NpuElementWiseOperation):
        block_type = NpuBlockType.ElementWise
    else:
        # Raised explicitly (rather than "assert 0") so the check is not stripped by python -O
        raise AssertionError("Unsupported operation")

    ifm_shape = to_block(npu_op.ifm.shape)
    ifm2_shape = to_block(npu_op.ifm2.shape) if npu_op.ifm2 else None
    ofm_shape = to_block(npu_op.ofm.shape)

    ifm_resampling_mode = resampling_mode_map[npu_op.ifm_upscale]
    ifm_bits = npu_op.ifm.data_type.size_in_bits()
    kernel = to_kernel(npu_op.kernel)
    # Only a table lookup activation needs LUT banks
    uses_lut = bool(npu_op.activation) and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP
    lut_banks = 2 if uses_lut else 0

    # Scaling is used only if every feature map that is present has quantization
    has_scaling = not any(fm and fm.quantization is None for fm in (npu_op.ifm, npu_op.ifm2, npu_op.ofm))

    arch = create_default_arch(Accelerator.from_npu_accelerator(accelerator))
    ublock = arch.ofm_ublock

    max_block_width = min(arch.ofm_block_max.width, ofm_shape.width)
    max_block_height = min(arch.ofm_block_max.height, ofm_shape.height)
    max_block_depth = min(arch.ofm_block_max.depth, ofm_shape.depth)

    # With IFM upscaling the block must be at least 2 wide/high. The test is done on
    # npu_op.ifm_upscale, which is a NpuResamplingMode: the value looked up in
    # resampling_mode_map may be of another enum type, and members of different
    # enums never compare equal, so comparing that value with
    # NpuResamplingMode.NONE could never detect the "no upscaling" case
    min_upscaled_size = 2 if npu_op.ifm_upscale != NpuResamplingMode.NONE else 1
    min_block_height = max(ublock.height, min_upscaled_size)
    min_block_width = max(ublock.width, min_upscaled_size)

    # Loop invariants
    split_depth = ArchitectureFeatures.OFMSplitDepth
    ifm2_is_scalar = npu_op.ifm2_scalar is not None

    valid_block_configs = []
    for w in range(min_block_width, max_block_width + min_block_width, min_block_width):
        for h in range(min_block_height, max_block_height + min_block_height, min_block_height):
            # Try valid OFM block depths
            for c in range(ublock.depth, max_block_depth + ublock.depth, ublock.depth):
                # OFM block depth has the constraint that if it causes the OFM to be
                # split, it must be a multiple of the OFM split size
                if c < max_block_depth and c % split_depth != 0:
                    continue
                config = try_block_config(
                    Block(w, h, c),
                    arch,
                    block_type,
                    ofm_shape,
                    ifm_shape,
                    ifm2_shape,
                    ifm2_is_scalar,
                    ifm_bits,
                    is_partkernel,
                    kernel,
                    lut_banks,
                    has_scaling,
                    ifm_resampling_mode,
                )
                if config:
                    ofm_block = config.ofm_block
                    valid_block_configs.append(NpuShape3D(ofm_block.height, ofm_block.width, ofm_block.depth))

    assert len(valid_block_configs) > 0
    return valid_block_configs
Louis Verhaard933f55e2020-11-25 14:10:30 +0100532
533
def npu_generate_register_command_stream(npu_op_list: List[NpuOperation], accelerator: NpuAccelerator) -> List[int]:
    """
    Public facing API for generating an Ethos-U register command stream.
    Calculates dependencies between commands and inserts wait operations if needed.

    :param npu_op_list: List[NpuOperation] list of high level NPU operations
    :param accelerator: NpuAccelerator enum to pick the correct accelerator
    :return: register commands, as a list of 32-bit integers
    """
    from . import register_command_stream_generator as generator

    command_stream = generator.generate_register_command_stream(npu_op_list, accelerator)
    return command_stream
Louis Verhaard52078302020-11-18 13:35:06 +0100546
547
def npu_create_driver_payload(register_command_stream: List[int], accelerator: NpuAccelerator) -> bytes:
    """
    Public facing API for generating driver payload, containing a driver header
    and the given Ethos-U register command stream.
    Returns the payload, in little endian format, which must be placed in memory on a 16-byte aligned
    address.

    :param register_command_stream: List[int] register commands, as a list of 32-bit integers
    :param accelerator: NpuAccelerator enum to pick the correct accelerator
    :return: driver payload, as a byte array
    """
    from . import driver_actions

    payload = driver_actions.npu_create_driver_payload(register_command_stream, accelerator)
    return payload