# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# Register level (low-level) command stream generation for Ethos-U55. Takes a high-level command stream and generates
# all the register settings, calculates dependencies between commands, inserts wait operations, and produces a bit
# stream suitable for interpretation by the Ethos-U55 processor.

from collections import defaultdict
from enum import Enum, IntEnum
from .high_level_command_stream import CommandType
from .ethos_u55_regs.ethos_u55_regs import *
from .tensor import MemArea, TensorBlockTraversal
from .operation import NpuBlockType
from .numeric_util import quantise_float32, round_up, round_away_zero, round_up_to_int, clamp_sigmoid, clamp_tanh
from .data_type import BaseType
import numpy as np
from .shared_buffer_allocation import SharedBufferAllocation
from .architecture_features import SharedBufferArea, SHRAMElements, ArchitectureFeatures
from .nn_graph import TensorFormat, SchedulingStrategy
from .range_set import (
    MemoryAccessSet,
    AccessDirection,
)
from .mark_tensors import (
    reshape_operations,
)
from .architecture_features import Block, Kernel, Rect
from . import scaling


class RegisterMachine:
    def __init__(self):
        self.n_banks = 1
        self.registers = [defaultdict(lambda: None) for _ in range(self.n_banks)]
        self.bank_idx = 0

    def set_register(self, reg, value):
        is_changed = self.registers[self.bank_idx][reg] != value
        self.registers[self.bank_idx][reg] = value
        # is_changed = True # force command
        return is_changed

    def switch_bank(self):
        self.bank_idx = (self.bank_idx + 1) % self.n_banks


class CmdMode(IntEnum):
    NoPayload = 0x0000
    Payload32 = 0x4000
    Mask = 0xC000
    CmdOpMask = 0x03FF


class BasePointerIndex(IntEnum):
    ReadOnly = 0  # base address slot index for weights and scaling
    Scratch = 1  # base address slot index for scratch memory area


# TODO: Replace with definitions from ethos_u55_regs
class IFM2Broadcast(IntEnum):
    BroadcastHdim = 1 << 0
    BroadcastWdim = 1 << 1
    BroadcastCdim = 1 << 2
    ReverseOperandOrder = 1 << 6
    UseIFM2Scalar = 1 << 7


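# Command word layout, as implied by the emitter below (illustrative description, not a
# normative statement of the hardware encoding): a cmd0 command is a single 32-bit word with
# the opcode in the lower 16 bits and the 16-bit parameter in the upper 16 bits; a cmd1
# command additionally sets the CmdMode.Payload32 bit and is followed by one 32-bit payload
# word. For example, if an opcode had the (hypothetical) value 0x0100, then
# cmd0_with_param(op, 2) would append the single word 0x00020100.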
class CommandStreamEmitter:
    def __init__(self):
        self.cmd_stream = []
        self.reg_machine = [RegisterMachine(), RegisterMachine()]
        self.last_absolute_wait = defaultdict(int)

    def get_reg_machine(self, cmd):
        if "DMA" in cmd.name:
            return self.reg_machine[1]
        else:
            return self.reg_machine[0]

    def size_in_bytes(self):
        sz = 0
        for cmd in self.cmd_stream:
            sz += len(cmd) * 4
        return sz

    def to_list(self):
        return [elem for cmd in self.cmd_stream for elem in cmd]

    def print_cmds(self):
        print("Code: Command: Param: Payload:")
        for words_for_one_command in self.cmd_stream:
            code = words_for_one_command[0] & 0x0000FFFF  # lower 16 bits
            param = words_for_one_command[0] >> 16  # higher 16 bits

            payload_mode = CmdMode(code & CmdMode.Mask)

            # code and command
            s = " 0x%04x " % code
            if payload_mode == CmdMode.NoPayload:
                s += str(cmd0(code & CmdMode.CmdOpMask))
            else:
                s += str(cmd1(code & CmdMode.CmdOpMask))

            s = s.ljust(40)
            s += "%5d" % param

            # payload
            if payload_mode == CmdMode.Payload32:
                s += " 0x%08x (%d)" % (words_for_one_command[1], words_for_one_command[1])
            else:
                s += " -"

            print(s)

    def cmd0_with_param(self, cmd, param):
        if isinstance(param, Enum):
            param = int(param.value)
        else:
            param = int(param)
        param = param & 0xFFFF
        command = cmd.value | (param << 16)
        if not self.get_reg_machine(cmd).set_register(cmd, (command, param)):
            return

        # This is not a redundant command, actually write it
        self.cmd_stream.append((command,))

    def cmd1_with_offset(self, cmd, offset, param=0x0):
        offset = int(offset) & 0xFFFFFFFF
        command = cmd.value | CmdMode.Payload32.value | (param << 16)

        if not self.get_reg_machine(cmd).set_register(cmd, (command, offset)):
            return

        # This is not a redundant command, actually write it
        self.cmd_stream.append((command, offset))

    def cmd_wait(self, cmd, param, absolute_wait_time):
        if absolute_wait_time <= self.last_absolute_wait[cmd]:
            return

        self.last_absolute_wait[cmd] = absolute_wait_time
        param = int(param)
        command = ((param & 0xFFFF) << 16) | cmd.value
        self.cmd_stream.append((command,))

    def cmd_do_operation(self, cmd, param=0):
        param = int(param)
        command = ((param & 0xFFFF) << 16) | cmd.value

        self.cmd_stream.append((command,))
        self.get_reg_machine(cmd).switch_bank()


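# The function below tracks, for every high-level command, how many NPU blocks/kernel
# commands and DMA commands precede it, and uses memory-access conflicts (plus shared-buffer
# compatibility for stripe-to-stripe pairs) to find the most recent command of each type that
# it must wait for. Illustrative example, assuming get_operation_count() returns
# (# NPU blocks, # DMA ops): after a DMA command followed by a stripe of 4 blocks, the
# accumulated pos is conceptually ((4, 1), (1,)) - 4 blocks and 1 kernel command on the NPU
# side, 1 command on the DMA side.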
def calc_command_dependencies(cmd_stream, arch):
    cmd_starts = {}
    cmd_ends = {}
    memory_accesses = {}

    # Keep track of accumulated number of commands in command stream.
    # First element kernel ops: (# of blocks, # of commands)
    # Second element DMA ops: (# of commands)
    pos = np.array((np.array((0, 0)), np.array([0])))

    dependencies = {}

    for cmd in cmd_stream:
        cmd_starts[cmd] = pos
        op_count = cmd.get_operation_count()
        # Keep track of both num blocks and commands
        cmd_add = 0 if (op_count[0] == 0) else 1
        pos = np.array((pos[0] + np.array((op_count[0], cmd_add)), pos[1] + np.array([op_count[1]])))
        cmd_ends[cmd] = np.array((pos[0], pos[1]))
        memory_accesses[cmd] = cmd.get_memory_accesses()

    for idx, cmd in enumerate(cmd_stream):
        curr_accesses = memory_accesses[cmd]
        # Keep track of command dependency.
        # First element kernel ops: (# of blocks, # of commands)
        # Second element DMA ops: (# of commands)
        dep_offsets = np.array((np.array((-1, -1)), np.array([-1])))
        dep_cmds = [None] * CommandType.Size.value
        if idx > 0:
            # Look at the previous commands in backwards order
            for prev_cmd in cmd_stream[idx - 1 :: -1]:
                assert prev_cmd is not cmd
                if dep_cmds[prev_cmd.cmdtype] is None:
                    is_dependency = False
                    if cmd.cmdtype == CommandType.NpuStripe and prev_cmd.cmdtype == CommandType.NpuStripe:
                        # Special handling here, as dpu -> dpu operations require additional care
                        if not SharedBufferAllocation.is_compatible(prev_cmd.ps.shared_buffer, cmd.ps.shared_buffer):
                            is_dependency = True
                        elif memory_accesses[prev_cmd].conflicts(curr_accesses):
                            is_dependency = True
                    else:
                        if memory_accesses[prev_cmd].conflicts(curr_accesses):
                            is_dependency = True

                    if is_dependency:
                        new_offset = cmd_ends[prev_cmd][prev_cmd.cmdtype]
                        if new_offset[0] > dep_offsets[prev_cmd.cmdtype][0]:
                            dep_cmds[prev_cmd.cmdtype] = prev_cmd
                            dep_offsets[prev_cmd.cmdtype] = new_offset

                        # Check if we've got dependencies for all commands, in which case we can early out
                        for dep in dep_cmds:
                            if dep is None:
                                break
                        else:
                            break  # all handled

        # Convert absolute to relative dependencies, using None to signal the special case of no
        # dependency of this kind
        res = [None] * CommandType.Size.value
        for i in range(CommandType.Size.value):
            if dep_cmds[i] is not None:
                res[i] = cmd_starts[cmd][i] - dep_offsets[i]

        dependencies[cmd] = cmd_starts[cmd], res

    return dependencies


def get_op_kernel(ps):
    if ps.primary_op is None:
        return None

    strides = ps.primary_op.attrs.get("strides", (1, 1, 1, 1))
    dilation = ps.primary_op.attrs.get("dilation", (1, 1, 1, 1))
    if ps.weight_tensor:
        if ps.npu_block_type in set((NpuBlockType.VectorProduct, NpuBlockType.ElementWise)):
            k_h = 1
            k_w = 1
        else:
            k_h = ps.weight_tensor.shape[0]
            k_w = ps.weight_tensor.shape[1]
    else:
        k_h = ps.primary_op.attrs.get("filter_height", 1)
        k_w = ps.primary_op.attrs.get("filter_width", 1)

    return Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])


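# Pads a shape list to 4 dimensions by prepending the fill value,
# e.g. full_shape([20, 30, 16], 1) -> [1, 20, 30, 16].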
def full_shape(shape, fill):
    return ([fill] * (4 - len(shape))) + shape


def has_prev_op_dependency(prev_cmd, cmd):
    if prev_cmd is None:
        return False
    if (prev_cmd.cmdtype == cmd.cmdtype == CommandType.NpuStripe) and (prev_cmd.ps != cmd.ps):
        if prev_cmd.ofm_tensor == cmd.ifm_tensor:
            return True
        else:
            return prev_cmd.ofm_tensor.equivalence_id == cmd.ifm_tensor.equivalence_id
    return False


def get_op_ofm_rect(cmd):
    start = full_shape(cmd.ofm_box.start_coord, 0)
    end = full_shape(cmd.ofm_box.end_coord, 1)
    return Rect(start[-2], start[-3], start[-1], end[-2] - 1, end[-3] - 1, end[-1] - 1)


def get_op_ifm_rect(cmd):
    start = full_shape(cmd.ifm_box.start_coord, 0)
    end = full_shape(cmd.ifm_box.end_coord, 1)
    return Rect(start[-2], start[-3], start[-1], end[-2] - 1, end[-3] - 1, end[-1] - 1)


def get_op_ifmofm_block_depth(arch, cmd):
    # Note: NOT equivalent to the normal ifm block depth calculation since
    # it takes into account 'depthless' block operations by returning full
    # depth
    if cmd.ps.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling, NpuBlockType.ElementWise):
        return cmd.ofm_box.get_size_shape()[-1]

    return arch.calc_ifm_block_depth(cmd.ifm_box.get_size_shape()[-1], cmd.ifm_tensor.dtype.bits)


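# Returns the (left, top) padding pair used for the block dependency calculation; e.g. with
# explicit_padding (top, left, bottom, right) = (1, 2, 1, 2) this yields (2, 1).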
def get_op_padding_lt(cmd):
    if cmd.ps.npu_block_type not in (
        NpuBlockType.ConvolutionDepthWise,
        NpuBlockType.Pooling,
        NpuBlockType.ConvolutionMxN,
    ):
        return (0, 0)

    explicit_padding = list(cmd.ps.primary_op.attrs["explicit_padding"])  # (top, left, bottom, right)

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        explicit_padding[0] = cmd.pad_top
        explicit_padding[2] = cmd.pad_bottom

    return (explicit_padding[1], explicit_padding[0])


def generate_register_command_stream(nng, sg, arch, verbose=False):
    emit = CommandStreamEmitter()

    base_ptr_idx_map = {
        MemArea.Sram: BasePointerIndex.Scratch,
        MemArea.OnChipFlash: BasePointerIndex.ReadOnly,
        MemArea.OffChipFlash: BasePointerIndex.ReadOnly,
        MemArea.Dram: BasePointerIndex.ReadOnly,
    }

    # Maps an AccumulatorType enum to the corresponding acc_format value
    acc_format_map = {
        SHRAMElements.Acc16: acc_format.FP_S5_10.value,
        SHRAMElements.Acc32: acc_format.INT_32BIT.value,
        SHRAMElements.Acc40: acc_format.INT_40BIT.value,
    }

    # Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
    elementwise_mode_map = {
        "MulAct": elementwise_mode.MUL.value,
        "AddAct": elementwise_mode.ADD.value,
        "SubAct": elementwise_mode.SUB.value,
        "Minimum": elementwise_mode.MIN.value,
        "Maximum": elementwise_mode.MAX.value,
        "LeakyRelu": elementwise_mode.LRELU.value,
        "Abs": elementwise_mode.ABS.value,
    }

    cmd_stream = []
    for cmd in sg.high_level_command_stream:
        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            cmd_stream.append(cmd)

    dependencies = calc_command_dependencies(cmd_stream, arch)

    # Initialise operator dependency state
    prev_ifm_rect = cur_ifm_rect = None
    prev_ifm_block_depth = cur_ifm_block_depth = None
    prev_ofm_rect = cur_ofm_rect = None
    prev_ofm_block = cur_ofm_block = None
    prev_kernel = cur_kernel = None
    prev_cmd = None

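    # The helper below converts the absolute/relative dependency counts computed by
    # calc_command_dependencies into NPU_OP_KERNEL_WAIT / NPU_OP_DMA_WAIT commands. The clamp
    # values 0xFFFF and 0xF are assumed to match the widths of the respective wait-count
    # fields in the command encoding.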
    def emit_wait_commands(cmd):
        # The command is fully set up, emit whatever wait commands we need
        absolute_dep, relative_dep = dependencies[cmd]
        if relative_dep[CommandType.NpuStripe] is not None:
            if cmd.cmdtype == CommandType.DMA:
                param = relative_dep[CommandType.NpuStripe][1]
                if param <= 3:
                    emit.cmd_wait(cmd0.NPU_OP_KERNEL_WAIT, param, absolute_dep[CommandType.NpuStripe][1])
            else:
                param = relative_dep[CommandType.NpuStripe][0]
                param = min(param, 0xFFFF)  # Clamp to allowable wait amount

        if relative_dep[CommandType.DMA] is not None:
            param = relative_dep[CommandType.DMA][0]
            param = min(param, 0xF)  # Clamp to allowable wait amount
            emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, param, absolute_dep[CommandType.DMA][0])
            prev_cmd = None  # Clear any dependency

    # Start by issuing REGION commands since they remain the same
    emit.cmd0_with_param(cmd0.NPU_SET_IFM_REGION, BasePointerIndex.Scratch)
    emit.cmd0_with_param(cmd0.NPU_SET_IFM2_REGION, BasePointerIndex.Scratch)
    emit.cmd0_with_param(cmd0.NPU_SET_OFM_REGION, BasePointerIndex.Scratch)
    for cmd in cmd_stream:
        if cmd.cmdtype == CommandType.DMA:
            start_coord = cmd.box.start_coord

            src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
            dst_addr = cmd.out_tensor.address_for_coordinate(start_coord)

            if cmd.in_tensor.compressed_values is not None:
                stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
                sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
            else:
                sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr

            # TODO: Yoda support needs to use feature_maps_not_in_fast_storage and force_outputs_to_fast_storage
            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_SRC_REGION, base_ptr_idx_map[cmd.in_tensor.mem_area])
            emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_SRC, src_addr)
            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_area])
            emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_DST, dst_addr)
            emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, sz)
            dma_channel = 0
            mode = 0  # From external to external
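            # As implied by the arithmetic below, the NPU_OP_DMA_START parameter packs the
            # channel into the upper nibble (dma_channel * 16) and the mode into the lower
            # nibble; with channel 0 and mode 0 the parameter is simply 0.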

            emit_wait_commands(cmd)
            emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, dma_channel * 16 + mode)

        elif cmd.cmdtype == CommandType.NpuStripe:

            ps = cmd.ps
            primary_op = ps.primary_op
            npu_block_type = ps.npu_block_type
            # Specifies if global scale from the NPU_SET_OFM_SCALE register should be used instead of per-channel scale
            use_global_scale = False
            # Specifies type of rounding to be used.
            rounding_mode = rounding.TFL
            fmf = primary_op.attrs.get("fused_memory_function", None)
            faf = primary_op.attrs.get("fused_activation_function", None)

            # Specifies which operand to apply scaling to in bitexact elementwise ADD/SUB
            op_to_scale = 0

            # Update state history
            prev_ifm_rect = cur_ifm_rect
            prev_ifm_block_depth = cur_ifm_block_depth
            prev_ofm_rect = cur_ofm_rect
            prev_ofm_block = cur_ofm_block
            prev_kernel = cur_kernel

            block_config = ps.block_config
            emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, block_config[0] - 1)
            emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_WIDTH_M1, block_config[1] - 1)
            emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_DEPTH_M1, block_config[3] - 1)

            shared_buffer = ps.shared_buffer

            if npu_block_type == NpuBlockType.ElementWise:
                ifm2_broadcast = 0

                if cmd.ifm_tensor.shape == []:
                    # The scalar has to be the ifm2 tensor so switch the ifms
                    cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
                    cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box

                    # Set ReverseOperandOrder bit to IFM2_BROADCAST
                    ifm2_broadcast |= IFM2Broadcast.ReverseOperandOrder

                # Calculate scales needed for arithmetic elementwise operators
                if primary_op.type in set(("AddAct", "MulAct", "SubAct",)):
                    input_scale = cmd.ifm_tensor.quantization.scale_f32
                    input2_scale = cmd.ifm2_tensor.quantization.scale_f32
                    output_scale = cmd.ofm_tensor.quantization.scale_f32
                    use_global_scale = True

                    if primary_op.type == "MulAct":
                        if (faf == "Sigmoid") or (faf == "Tanh"):
                            output_scale = 1 / 0x3000

                        ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
                        emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
                    else:  # AddAct/SubAct
                        if (faf == "Sigmoid") or (faf == "Tanh"):
                            output_scale = 1 / 0x3000

                        if input_scale == input2_scale:
                            opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
                                input_scale, input2_scale, output_scale
                            )
                            opa_shift = 0  # Unused for this case
                        else:
                            # Use advanced implementation only when input scales differ
                            bitdepth = cmd.ifm_tensor.dtype.bits
                            (
                                opa_scale,
                                opa_shift,
                                ofm_scale,
                                shift,
                                op_to_scale,
                            ) = scaling.advanced_elementwise_add_sub_scale(
                                input_scale, input2_scale, output_scale, bitdepth
                            )
                            opb_scale = 0  # Unused for this case
                            if ifm2_broadcast & IFM2Broadcast.ReverseOperandOrder:
                                # If the operand order is reversed we also have to swap which operand is scaled
                                if op_to_scale == scaling.OperandToScale.OPa:
                                    op_to_scale = scaling.OperandToScale.OPb
                                else:
                                    op_to_scale = scaling.OperandToScale.OPa

                        emit.cmd1_with_offset(cmd1.NPU_SET_OPA_SCALE, opa_scale, opa_shift)
                        emit.cmd1_with_offset(cmd1.NPU_SET_OPB_SCALE, opb_scale)
                        emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)

                if primary_op.type in set(("LeakyRelu", "Abs",)):
                    output_scale = cmd.ofm_tensor.quantization.scale_f32
                    use_global_scale = True

                    if primary_op.type == "LeakyRelu":
                        output_scale *= primary_op.attrs["alpha"]

                    ofm_scale, shift = scaling.quantise_scale(output_scale)
                    emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)

                # For elementwise set the required SHRAM to be equal to the total size of SHRAM
                shram_required = arch.shram_total_banks
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required)

                # Acc buffers not needed so set AB_START to size of SHRAM
                emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch.shram_total_banks)

                # Is not a unary operator
                if cmd.ifm2_tensor is not None:
                    if cmd.ifm2_tensor.shape == []:
                        # IFM2 is a constant, set UseIFM2Scalar bit to IFM2_BROADCAST
                        ifm2_broadcast |= IFM2Broadcast.UseIFM2Scalar
                    else:
                        ifm_box_shape = cmd.ifm_box.get_size_shape()
                        ifm2_box_shape = cmd.ifm2_box.get_size_shape()

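                        # Illustrative example: an IFM box of shape (1, 8, 8, 16) paired with
                        # an IFM2 box of shape (1, 1, 8, 16) differs in the H dimension, so
                        # BroadcastHdim is set by the checks below (IFM2 must be 1 in any
                        # broadcast dimension, as the asserts enforce).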
                        if len(cmd.ifm_tensor.shape) > 1 and ifm_box_shape[1] != ifm2_box_shape[1]:
                            # Broadcast in 'H' dimension
                            assert cmd.ifm2_tensor.shape[1] == 1
                            ifm2_broadcast |= IFM2Broadcast.BroadcastHdim

                        if len(cmd.ifm_tensor.shape) > 2 and ifm_box_shape[2] != ifm2_box_shape[2]:
                            # Broadcast in 'W' dimension
                            assert cmd.ifm2_tensor.shape[2] == 1
                            ifm2_broadcast |= IFM2Broadcast.BroadcastWdim

                        if len(cmd.ifm_tensor.shape) > 3 and ifm_box_shape[3] != ifm2_box_shape[3]:
                            # Broadcast in 'C' dimension
                            assert cmd.ifm2_tensor.shape[3] == 1
                            ifm2_broadcast |= IFM2Broadcast.BroadcastCdim

                    # Set IFM2_IB_START to the latter half of the IB space
                    ifm_ib_start = shared_buffer.bank_locations[SharedBufferArea.IFM]
                    emit.cmd0_with_param(
                        cmd0.NPU_SET_IFM2_IB_START, (shram_required - ifm_ib_start) / 2 + ifm_ib_start
                    )

                    emit.cmd0_with_param(cmd0.NPU_SET_IFM2_BROADCAST, ifm2_broadcast)

            else:
                emit.cmd0_with_param(
                    cmd0.NPU_SET_IFM_IB_END,
                    shared_buffer.bank_locations[SharedBufferArea.IFM]
                    + shared_buffer.banks_required[SharedBufferArea.IFM],
                )
                emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shared_buffer.bank_locations[SharedBufferArea.Accumulators])

            emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element])

            emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, 0)

            if npu_block_type in set(
                (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling)
            ):
                # Set up padding
                explicit_padding = list(primary_op.attrs["explicit_padding"])  # (top, left, bottom, right)

                # Check if this is for horizontal ifm streaming
                if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
                    explicit_padding[0] = cmd.pad_top
                    explicit_padding[2] = cmd.pad_bottom

                # Index from the end, since a 1x1 AvgPool with non 4-dimensional input/output may have been
                # added because an activation function needed to be fused.
                if cmd.ifm_box.start_coord[-2] > 0:
                    explicit_padding[1] = 0
                if cmd.ifm_box.end_coord[-2] < cmd.ifm_tensor.shape[-2]:
                    explicit_padding[3] = 0

                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, explicit_padding[0])
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_LEFT, explicit_padding[1])
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_BOTTOM, explicit_padding[2])
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_RIGHT, explicit_padding[3])

                stride = primary_op.attrs["strides"][2] - 1
                stride |= (primary_op.attrs["strides"][1] - 1) << 1
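                # Worked example: a strides attribute of (1, 2, 2, 1) gives (strides[2] - 1) = 1
                # and ((strides[1] - 1) << 1) = 2, so the stride field is 0b11 before any
                # part-kernel-first bit is OR'd in further down.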

                if npu_block_type == NpuBlockType.Pooling:
                    k_height, k_width = primary_op.attrs["ksize"][1:3]
                    emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, k_height - 1)
                    emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, k_width - 1)

                    valid_padding = sum(explicit_padding) == 0

                    if primary_op.type in set(("AvgPool", "AvgPoolAct")) and valid_padding:
                        # For valid padding vela has to output scaling values
                        if faf == "Sigmoid" or faf == "Tanh":
                            rescale = 0x3000 * cmd.ifm_tensor.quantization.scale_f32
                            rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1

                            scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
                            scale = int(round_away_zero(scale * rescale))
                        else:
                            # If the average pool is fused with a concat or other memory operation, rescaling
                            # might be needed. k_height == k_width == 1 is always true in this case.
                            # Normally the scale is maximised to get maximum precision, which means that if
                            # rescale != 1, the scale needs to account for the number of bits needed for rescaling
                            rescale = cmd.ifm_tensor.quantization.scale_f32 / cmd.ofm_tensor.quantization.scale_f32
                            rescale_bits = 0
                            if k_height == k_width == 1:
                                if fmf == "ConcatSliceWrite":
                                    rounding_mode = rounding.NATURAL
                                if rescale > 1:
                                    rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
                                elif rescale < 1:
                                    rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1)
                            scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
                            scale = int(round_away_zero(scale * rescale))

                        emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, scale, shift)
                        # Valid-padded average pool should use the global scale from
                        # NPU_SET_OFM_SCALE register, which is set above.
                        use_global_scale = True

                else:  # Convolution
                    assert cmd.weight_tensor.block_traversal != TensorBlockTraversal.Default
                    emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, cmd.weight_tensor.shape[0] - 1)
                    emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, cmd.weight_tensor.shape[1] - 1)
                    if cmd.weight_tensor.block_traversal == TensorBlockTraversal.PartKernelFirst:
                        # Part-kernel-first weight ordering
                        assert npu_block_type == NpuBlockType.ConvolutionMxN
                        stride |= 1 << 2

                emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_STRIDE, stride)

            elif npu_block_type in set((NpuBlockType.VectorProduct,)):
                # Vector product is implemented using a 1x1 convolution, so we need to
                # set up the appropriate padding and kernel info
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, 0)
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_LEFT, 0)
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_BOTTOM, 0)
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_RIGHT, 0)

                # kernel stride reg = 0 means stride(1,1) + depth first weight
                # order + dilation(0,0) + kernel_split_size=8
                emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_STRIDE, 0)

                emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, 0)
                emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, 0)

            if npu_block_type in set(
                (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.VectorProduct)
            ):
                # Emit Weight base address commands, only maps the area required for
                # this command's weights from the larger tensor.
                stream_index = cmd.weight_tensor.compressed_stream_index_from_coord(cmd.weight_box.start_coord)
                weight_addr = cmd.weight_tensor.address_for_coordinate(cmd.weight_box.start_coord)
                weight_len = cmd.weight_tensor.size_of_compressed_stream(stream_index)
                # Select weight/scale region depending on where permanent storage was defined
                weight_region = base_ptr_idx_map[cmd.weight_tensor.mem_area]
                if arch.permanent_storage_mem_area == MemArea.Sram:
                    weight_region = BasePointerIndex.ReadOnly
                emit.cmd0_with_param(cmd0.NPU_SET_WEIGHT_REGION, weight_region)
                emit.cmd1_with_offset(cmd1.NPU_SET_WEIGHT_BASE, weight_addr)
                emit.cmd1_with_offset(cmd1.NPU_SET_WEIGHT_LENGTH, weight_len)

                # Emit Scale & Bias base address commands, with length matching the amount required by
                # the weight tensors.
                if cmd.scale_tensor is not None:
                    # Get address and size of the scale/bias data area
                    scale_addr = cmd.scale_tensor.address_for_coordinate(cmd.weight_box.start_coord[-1:])
                    scale_len = (
                        cmd.scale_tensor.address_for_coordinate(cmd.weight_box.end_coord[-1:], True) - scale_addr
                    )
                    # Emit base address for NPU to access scale & bias data
                    scale_region = base_ptr_idx_map[cmd.scale_tensor.mem_area]
                    if arch.permanent_storage_mem_area == MemArea.Sram:
                        scale_region = BasePointerIndex.ReadOnly
                    emit.cmd0_with_param(cmd0.NPU_SET_SCALE_REGION, scale_region)
                    emit.cmd1_with_offset(cmd1.NPU_SET_SCALE_BASE, scale_addr)
                    emit.cmd1_with_offset(cmd1.NPU_SET_SCALE_LENGTH, round_up(scale_len, 16))

            ofm_quant = cmd.ofm_tensor.quantization
            ofm_quant_qmin = cmd.ofm_tensor.quantization.quant_min
            ofm_quant_qmax = cmd.ofm_tensor.quantization.quant_max
            ifm_min = cmd.ifm_tensor.quantization.min
            ifm_max = cmd.ifm_tensor.quantization.max

            # Emit commands for any fused activation function
            if faf is None:
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                # Even if no activation function, values need to be set to override previous values
                faf_min = ofm_quant_qmin
                faf_max = ofm_quant_qmax
            elif faf == "Relu":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = ofm_quant_qmax
            elif faf == "Relu6":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = quantise_float32(6.0, ofm_quant.scale_f32, ofm_quant.zero_point)
            elif faf == "ReluN1To1":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                faf_min = quantise_float32(-1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
            elif faf == "Tanh":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.TANH)
                faf_min = quantise_float32(clamp_tanh(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = quantise_float32(clamp_tanh(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
            elif faf == "Sigmoid":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.SIGMOID)
                faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
            else:
                raise Exception("Unsupported fused_activation_function = " + faf)

            # Activation range needs to be set based upon the quantisation range and the fused activation range
            emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MIN, max(ofm_quant_qmin, faf_min))
            emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MAX, min(ofm_quant_qmax, faf_max))

            out_shape = cmd.ofm_box.get_size_shape()
            if len(out_shape) >= 4:
                emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT_M1, out_shape[-3] - 1)
            else:
                emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT_M1, 0)
            if len(out_shape) >= 2:
                emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, out_shape[-2] - 1)
            else:
                emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, 0)
            emit.cmd0_with_param(cmd0.NPU_SET_OFM_DEPTH_M1, out_shape[-1] - 1)

            if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
                in_shape = cmd.ifm_box.get_size_shape()
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, in_shape[-1] - 1)
            else:
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, out_shape[-1] - 1)

            for tens, box, ptr_ops, stride_ops, zero_point_op in (
                (
                    cmd.ifm_tensor,
                    cmd.ifm_box,
                    (cmd1.NPU_SET_IFM_BASE0, cmd1.NPU_SET_IFM_BASE1, cmd1.NPU_SET_IFM_BASE2, cmd1.NPU_SET_IFM_BASE3),
                    (cmd1.NPU_SET_IFM_STRIDE_C, cmd1.NPU_SET_IFM_STRIDE_Y, cmd1.NPU_SET_IFM_STRIDE_X),
                    cmd0.NPU_SET_IFM_ZERO_POINT,
                ),
                (
                    cmd.ifm2_tensor,
                    cmd.ifm2_box,
                    (
                        cmd1.NPU_SET_IFM2_BASE0,
                        cmd1.NPU_SET_IFM2_BASE1,
                        cmd1.NPU_SET_IFM2_BASE2,
                        cmd1.NPU_SET_IFM2_BASE3,
                    ),
                    (cmd1.NPU_SET_IFM2_STRIDE_C, cmd1.NPU_SET_IFM2_STRIDE_Y, cmd1.NPU_SET_IFM2_STRIDE_X),
                    cmd0.NPU_SET_IFM2_ZERO_POINT,
                ),
                (
                    cmd.ofm_tensor,
                    cmd.ofm_box,
                    (cmd1.NPU_SET_OFM_BASE0, cmd1.NPU_SET_OFM_BASE1, cmd1.NPU_SET_OFM_BASE2, cmd1.NPU_SET_OFM_BASE3),
                    (cmd1.NPU_SET_OFM_STRIDE_C, cmd1.NPU_SET_OFM_STRIDE_Y, cmd1.NPU_SET_OFM_STRIDE_X),
                    cmd0.NPU_SET_OFM_ZERO_POINT,
                ),
            ):

                if tens is None:
                    continue

                need_zero_point = (faf is not None) or (fmf == "ConcatSliceWrite")
                if (
                    primary_op.type in set(("AvgPool", "AvgPoolAct")) and not need_zero_point
                ) or tens.quantization is None:
                    # Actual integer operation, just set scale to 1 and zero point to 0
                    emit.cmd0_with_param(zero_point_op, 0)
                else:
                    assert tens.quantization.zero_point is not None, "need an actual zero point set"
                    emit.cmd0_with_param(zero_point_op, int(tens.quantization.zero_point))

                if tens.shape == []:
                    # Empty shape, elementwise constant
                    ifm2_scalar = tens.quant_values.astype(np.uint8)
                    assert ifm2_scalar.size == 1
                    emit.cmd0_with_param(cmd0.NPU_SET_IFM2_SCALAR, ifm2_scalar.item(0))
                    continue

                height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
                    box.start_coord, box.end_coord
                )
                if npu_block_type != NpuBlockType.VectorProduct:
                    if tens == cmd.ifm_tensor:
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM_HEIGHT0_M1, height_0 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM_HEIGHT1_M1, height_1 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM_WIDTH0_M1, width_0 - 1)
                    elif tens == cmd.ofm_tensor:
                        emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT0_M1, height_0 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT1_M1, height_1 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH0_M1, width_0 - 1)
                    elif tens == cmd.ifm2_tensor:
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM2_HEIGHT0_M1, height_0 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM2_HEIGHT1_M1, height_1 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM2_WIDTH0_M1, width_0 - 1)
                else:
                    if len(out_shape) == 2:
                        # TODO: N is put in W-dimension for now
                        # Should be spread over H and W, but then block size selection
                        # and stride calculation would have to change
                        if tens == cmd.ifm_tensor:
                            emit.cmd0_with_param(cmd0.NPU_SET_IFM_WIDTH0_M1, out_shape[-2] - 1)
                        elif tens == cmd.ofm_tensor:
                            emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH0_M1, out_shape[-2] - 1)
                    else:
                        assert False

                for idx, addr in enumerate(addresses):
                    if addr is None:
                        addresses[idx] = 0

                emit.cmd1_with_offset(ptr_ops[0], addresses[0])
                emit.cmd1_with_offset(ptr_ops[1], addresses[1])
                emit.cmd1_with_offset(ptr_ops[2], addresses[2])
                emit.cmd1_with_offset(ptr_ops[3], addresses[3])

                strides = tens.get_strides()
                emit.cmd1_with_offset(stride_ops[0], strides[1])  # stride between 16-byte channel blocks (C)
                emit.cmd1_with_offset(stride_ops[2], strides[3])  # stride between horizontal values (W)
                emit.cmd1_with_offset(stride_ops[1], strides[2])  # stride between vertical values (H)

                if tens.format == TensorFormat.NHCWB16:
                    # Check that all BasePointer addresses are aligned to 16 bytes
                    assert (int(addresses[0]) % 16) == 0
                    assert (int(addresses[1]) % 16) == 0
                    assert (int(addresses[2]) % 16) == 0
                    assert (int(addresses[3]) % 16) == 0

            ofm_dtype = cmd.ofm_tensor.dtype
            assert ofm_dtype.type & BaseType.Int
            prec = 0
            if ofm_dtype.size_in_bits() == 8:
                prec = 0
            elif ofm_dtype.size_in_bits() == 16:
                prec = 2
            else:
                assert 0

            if ofm_dtype.type & BaseType.Signed:
                prec += 1

            if use_global_scale:
                # Set global scale bit, as opposed to using per channel scale
                prec |= 1 << 8

            if cmd.ofm_tensor.format == TensorFormat.NHCWB16:
                prec |= 1 << 6

            prec |= rounding_mode.value << 14
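            # Worked example: a signed int8 OFM with use_global_scale and NHCWB16 format gives
            # prec = 1 | (1 << 8) | (1 << 6) = 0x141, before the rounding mode is OR'd into
            # bits 14..15 above.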

            emit.cmd0_with_param(cmd0.NPU_SET_OFM_PRECISION, prec)

            prec = None
            weight_bits = 8
            if cmd.weight_tensor is not None:
                weight_bits = cmd.weight_tensor.dtype.size_in_bits()

            ifm_dtype = cmd.ifm_tensor.dtype

            assert weight_bits == 8, "Unsupported weight bit depth"
            assert ifm_dtype.size_in_bits() in {8, 16}

            if ifm_dtype.size_in_bits() == 8:
                if ifm_dtype.type & BaseType.Signed:
                    prec = ifm_precision.W8_S8
                else:
                    prec = ifm_precision.W8_U8
            elif ifm_dtype.size_in_bits() == 16:
                if ifm_dtype.type & BaseType.Signed:
                    prec = ifm_precision.W8_S16
                else:
                    prec = ifm_precision.W8_U16

            ifm_prec = prec.value
            ifm2_prec = ifm_prec

            if cmd.ifm_tensor.format == TensorFormat.NHCWB16:
                ifm_prec |= 1 << 6

            ifm_prec |= op_to_scale << 8

            emit.cmd0_with_param(cmd0.NPU_SET_IFM_PRECISION, ifm_prec)

            if cmd.ifm2_tensor is not None:
                if cmd.ifm2_tensor.format == TensorFormat.NHCWB16:
                    ifm2_prec |= 1 << 6
                emit.cmd0_with_param(cmd0.NPU_SET_IFM2_PRECISION, ifm2_prec)

            emit_wait_commands(cmd)

            # Get op parameters
            cur_ifm_block_depth = get_op_ifmofm_block_depth(arch, cmd)
            cur_ofm_block = Block(ps.block_config[1], ps.block_config[0], ps.block_config[3])
            cur_ofm_rect = get_op_ofm_rect(cmd)
            cur_ifm_rect = get_op_ifm_rect(cmd)
            cur_kernel = get_op_kernel(cmd.ps)
            cur_padLT = get_op_padding_lt(cmd)
            if (prev_kernel is not None) and (cur_kernel is not None) and has_prev_op_dependency(prev_cmd, cmd):
                if cmd.ifm_tensor.shape == prev_cmd.ofm_tensor.shape:
                    blockdep = arch.calc_block_dep(
                        prev_ifm_rect,
                        prev_ofm_rect,
                        prev_ifm_block_depth,
                        prev_ofm_block,
                        prev_kernel,
                        cur_ifm_rect,
                        cur_ofm_rect,
                        cur_ifm_block_depth,
                        cur_ofm_block,
                        cur_kernel,
                        cur_padLT,
                    )
                else:
                    blockdep = 0
            else:
                blockdep = ArchitectureFeatures.MAX_BLOCKDEP

            # Set between every op (dependent or not)
            blockdep = min(blockdep, arch.max_blockdep)
            emit.cmd0_with_param(cmd0.NPU_SET_BLOCKDEP, blockdep)
            prev_cmd = cmd

            if npu_block_type == NpuBlockType.ConvolutionMxN:
                emit.cmd_do_operation(cmd0.NPU_OP_CONV)
            elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
                emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE)
            elif npu_block_type == NpuBlockType.VectorProduct:
                # Vector product is implemented using a 1x1 convolution
                emit.cmd_do_operation(cmd0.NPU_OP_CONV)
            elif npu_block_type == NpuBlockType.Pooling:
                param = "Max" not in primary_op.type
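                # As written, any pooling op without "Max" in its type (i.e. average pooling)
                # gets param 1 and max pooling gets param 0.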
                emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=param)
            elif npu_block_type == NpuBlockType.ElementWise:
                param = elementwise_mode_map[primary_op.type]
                emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param)
            else:
                print("Warning: Skipping register command stream generation for", ps)

    # Fill in final part of command stream:
    emit.cmd_do_operation(cmd0.NPU_OP_STOP, param=0xFFFF)

    sg.register_command_stream = emit.to_list()
    if verbose:
        emit.print_cmds()
        print("number of commands", len(emit.cmd_stream))
        print("command stream length in words", len(sg.register_command_stream))