# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# Register level (low-level) command stream generation for Ethos-U55. Takes a high-level command stream and generates
# all the register settings. Calculates dependencies between commands, inserts wait operations, and generates a bit
# stream suitable for interpretation by the Ethos-U55 processor.

from collections import defaultdict
from enum import Enum, IntEnum

import numpy as np

from . import scaling
from .high_level_command_stream import CommandType
from .ethos_u55_regs.ethos_u55_regs import cmd0, cmd1, acc_format, elementwise_mode, rounding, activation, ifm_precision
from .tensor import MemArea, TensorBlockTraversal, TensorFormat
from .operation import NpuBlockType
from .numeric_util import quantise_float32, round_up, round_away_zero, round_up_to_int, clamp_sigmoid, clamp_tanh
from .data_type import BaseType, DataType
from .shared_buffer_allocation import SharedBufferAllocation
from .architecture_features import SharedBufferArea, SHRAMElements, ArchitectureFeatures
from .architecture_features import Block, Kernel, Rect


class RegisterMachine:
    def __init__(self):
        self.n_banks = 1
        self.registers = [defaultdict(lambda: None) for _ in range(self.n_banks)]
        self.bank_idx = 0

    def set_register(self, reg, value):
        is_changed = self.registers[self.bank_idx][reg] != value
        self.registers[self.bank_idx][reg] = value
        # is_changed = True # force command
        return is_changed

    def switch_bank(self):
        self.bank_idx = (self.bank_idx + 1) % self.n_banks

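# Usage sketch (illustrative): RegisterMachine caches the last value written to each
# register so that redundant register writes can be dropped from the command stream.
# For some register enum value reg:
#
#   rm = RegisterMachine()
#   rm.set_register(reg, 5)  # True  -> value changed, the command should be emitted
#   rm.set_register(reg, 5)  # False -> same value, the command can be skipped
#   rm.set_register(reg, 7)  # True  -> emit again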

class CmdMode(IntEnum):
    NoPayload = 0x0000
    Payload32 = 0x4000
    Mask = 0xC000
    CmdOpMask = 0x03FF

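# Command word layout sketch (illustrative, derived from CmdMode and CommandStreamEmitter
# below): every command starts with one 32-bit word carrying the 10-bit opcode in bits 0-9,
# the payload mode in bits 14-15 and a 16-bit parameter in bits 16-31. Payload32 commands
# are followed by one extra 32-bit payload word (e.g. an address offset).
#
#   31 ................ 16 | 15 14 | 13 ......... 10 | 9 .............. 0
#   [     16-bit param    ] [ mode ] [ unused here   ] [ cmd0/cmd1 code  ]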

class BasePointerIndex(IntEnum):
    ReadOnly = 0  # base address slot index for weights and scaling
    Scratch = 1  # base address slot index for scratch memory area


# TODO: Replace with definitions from ethos_u55_regs
class IFM2Broadcast(IntEnum):
    BroadcastHdim = 1 << 0
    BroadcastWdim = 1 << 1
    BroadcastCdim = 1 << 2
    ReverseOperandOrder = 1 << 6
    UseIFM2Scalar = 1 << 7


class CommandStreamEmitter:
    def __init__(self):
        self.cmd_stream = []
        self.reg_machine = [RegisterMachine(), RegisterMachine()]
        self.last_absolute_wait = defaultdict(int)

    def get_reg_machine(self, cmd):
        if "DMA" in cmd.name:
            return self.reg_machine[1]
        else:
            return self.reg_machine[0]

    def size_in_bytes(self):
        sz = 0
        for cmd in self.cmd_stream:
            sz += len(cmd) * 4
        return sz

    def to_list(self):
        return [elem for cmd in self.cmd_stream for elem in cmd]

    def print_cmds(self):
        print("Code: Command: Param: Payload:")
        for words_for_one_command in self.cmd_stream:
            code = words_for_one_command[0] & 0x0000FFFF  # lower 16 bits
            param = words_for_one_command[0] >> 16  # higher 16 bits

            payload_mode = CmdMode(code & CmdMode.Mask)

            # code and command
            s = " 0x%04x " % code
            if payload_mode == CmdMode.NoPayload:
                s += str(cmd0(code & CmdMode.CmdOpMask))
            else:
                s += str(cmd1(code & CmdMode.CmdOpMask))

            s = s.ljust(40)
            s += "%5d" % param

            # payload
            if payload_mode == CmdMode.Payload32:
                s += " 0x%08x (%d)" % (words_for_one_command[1], words_for_one_command[1])
            else:
                s += " -"

            print(s)

    def cmd0_with_param(self, cmd, param):
        if isinstance(param, Enum):
            param = int(param.value)
        else:
            param = int(param)
        param = param & 0xFFFF
        command = cmd.value | (param << 16)
        if not self.get_reg_machine(cmd).set_register(cmd, (command, param)):
            return

        # This is not a redundant command, actually write it
        self.cmd_stream.append((command,))

    def cmd1_with_offset(self, cmd, offset, param=0x0):
        offset = int(offset) & 0xFFFFFFFFF
        command = cmd.value | CmdMode.Payload32.value | (param << 16)

        if not self.get_reg_machine(cmd).set_register(cmd, (command, offset)):
            return

        # This is not a redundant command, actually write it
        self.cmd_stream.append((command, offset))

    def cmd_wait(self, cmd, param, absolute_wait_time):
        if absolute_wait_time <= self.last_absolute_wait[cmd]:
            return

        self.last_absolute_wait[cmd] = absolute_wait_time
        param = int(param)
        command = ((param & 0xFFFF) << 16) | cmd.value
        self.cmd_stream.append((command,))

    def cmd_do_operation(self, cmd, param=0):
        param = int(param)
        command = ((param & 0xFFFF) << 16) | cmd.value

        self.cmd_stream.append((command,))
        self.get_reg_machine(cmd).switch_bank()

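# Usage sketch (illustrative): generate_register_command_stream below drives the emitter
# roughly like this, first setting registers and then launching an operation.
#
#   emit = CommandStreamEmitter()
#   emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, 0)     # one-word command, param in upper half
#   emit.cmd1_with_offset(cmd1.NPU_SET_IFM_BASE0, 0x100)  # two-word command with 32-bit payload
#   emit.cmd_do_operation(cmd0.NPU_OP_CONV)               # kick off the op and switch register bank
#   words = emit.to_list()                                # flat list of 32-bit command words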

def calc_command_dependencies(cmd_stream, arch):
    cmd_starts = {}
    cmd_ends = {}
    memory_accesses = {}

    # Keep track of accumulated number of commands in command stream.
    # First element kernel ops: (# of blocks, # of commands)
    # Second element DMA ops: (# of commands)
    pos = np.array((np.array((0, 0)), np.array([0])))

    dependencies = {}

    for cmd in cmd_stream:
        cmd_starts[cmd] = pos
        op_count = cmd.get_operation_count()
        # Keep track of both num blocks and commands
        cmd_add = 0 if (op_count[0] == 0) else 1
        pos = np.array((pos[0] + np.array((op_count[0], cmd_add)), pos[1] + np.array([op_count[1]])))
        cmd_ends[cmd] = np.array((pos[0], pos[1]))
        memory_accesses[cmd] = cmd.get_memory_accesses()

    for idx, cmd in enumerate(cmd_stream):
        curr_accesses = memory_accesses[cmd]
        # Keep track of command dependency.
        # First element kernel ops: (# of blocks, # of commands)
        # Second element DMA ops: (# of commands)
        dep_offsets = np.array((np.array((-1, -1)), np.array([-1])))
        dep_cmds = [None] * CommandType.Size.value
        if idx > 0:
            # Look at the previous commands in backwards order
            for prev_cmd in cmd_stream[idx - 1 :: -1]:
                assert prev_cmd is not cmd
                if dep_cmds[prev_cmd.cmdtype] is None:
                    is_dependency = False
                    if cmd.cmdtype == CommandType.NpuStripe and prev_cmd.cmdtype == CommandType.NpuStripe:
                        # Special handling here, as dpu -> dpu operations require additional care
                        if not SharedBufferAllocation.is_compatible(prev_cmd.ps.shared_buffer, cmd.ps.shared_buffer):
                            is_dependency = True
                        elif memory_accesses[prev_cmd].conflicts(curr_accesses):
                            is_dependency = True
                    else:
                        if memory_accesses[prev_cmd].conflicts(curr_accesses):
                            is_dependency = True

                    if is_dependency:
                        new_offset = cmd_ends[prev_cmd][prev_cmd.cmdtype]
                        if new_offset[0] > dep_offsets[prev_cmd.cmdtype][0]:
                            dep_cmds[prev_cmd.cmdtype] = prev_cmd
                            dep_offsets[prev_cmd.cmdtype] = new_offset

                        # Check if we've got dependencies for all commands, in which case we can early out
                        for dep in dep_cmds:
                            if dep is None:
                                break
                        else:
                            break  # all handled

        # Convert absolute to relative dependencies, using None to signal the special case of no
        # dependency of this kind
        res = [None] * CommandType.Size.value
        for i in range(CommandType.Size.value):
            if dep_cmds[i] is not None:
                res[i] = cmd_starts[cmd][i] - dep_offsets[i]

        dependencies[cmd] = cmd_starts[cmd], res

    return dependencies

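# Result sketch (illustrative): for every command, calc_command_dependencies pairs its
# absolute start position with a per-command-type list of relative offsets, i.e.
# dependencies[cmd] == (cmd_starts[cmd], res) where res[CommandType.NpuStripe] and
# res[CommandType.DMA] are either an offset or None for "no dependency of that kind";
# emit_wait_commands later turns these into kernel/DMA wait commands.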

def get_op_kernel(ps):
    if ps.primary_op is None:
        return None

    strides = ps.primary_op.attrs.get("strides", (1, 1, 1, 1))
    dilation = ps.primary_op.attrs.get("dilation", (1, 1, 1, 1))
    if ps.weight_tensor:
        if ps.npu_block_type in set((NpuBlockType.VectorProduct, NpuBlockType.ElementWise)):
            k_h = 1
            k_w = 1
        else:
            k_h = ps.weight_tensor.shape[0]
            k_w = ps.weight_tensor.shape[1]
    else:
        k_h = ps.primary_op.attrs.get("filter_height", 1)
        k_w = ps.primary_op.attrs.get("filter_width", 1)

    return Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])


def full_shape(shape, fill):
    return ([fill] * (4 - len(shape))) + shape

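# Example (illustrative): full_shape pads a shape out to 4 dimensions with the given fill
# value, e.g. full_shape([20, 30, 16], 1) == [1, 20, 30, 16]; an already 4-dimensional
# shape is returned unchanged.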

def has_prev_op_dependency(prev_cmd, cmd):
    if prev_cmd is None:
        return False
    if (prev_cmd.cmdtype == cmd.cmdtype == CommandType.NpuStripe) and (prev_cmd.ps != cmd.ps):
        if prev_cmd.ofm_tensor == cmd.ifm_tensor:
            return True
        else:
            return prev_cmd.ofm_tensor.equivalence_id == cmd.ifm_tensor.equivalence_id
    return False


def get_op_ofm_rect(cmd):
    start = full_shape(cmd.ofm_box.start_coord, 0)
    end = full_shape(cmd.ofm_box.end_coord, 1)
    return Rect(start[-2], start[-3], start[-1], end[-2] - 1, end[-3] - 1, end[-1] - 1)


def get_op_ifm_rect(cmd):
    start = full_shape(cmd.ifm_box.start_coord, 0)
    end = full_shape(cmd.ifm_box.end_coord, 1)
    return Rect(start[-2], start[-3], start[-1], end[-2] - 1, end[-3] - 1, end[-1] - 1)


def get_op_ifmofm_block_depth(arch, cmd):
    # Note: NOT equivalent to the normal ifm block depth calculation since
    # it takes into account 'depthless' block operations by returning full
    # depth
    if cmd.ps.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling, NpuBlockType.ElementWise):
        return cmd.ofm_box.get_size_shape()[-1]

    return arch.calc_ifm_block_depth(cmd.ifm_box.get_size_shape()[-1], cmd.ifm_tensor.dtype.bits)


def get_op_padding_lt(cmd):
    if cmd.ps.npu_block_type not in (
        NpuBlockType.ConvolutionDepthWise,
        NpuBlockType.Pooling,
        NpuBlockType.ConvolutionMxN,
    ):
        return (0, 0)

    explicit_padding = list(cmd.ps.primary_op.attrs["explicit_padding"])  # (top, left, bottom, right)

    # Check if this is for horizontal ifm streaming
    if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
        explicit_padding[0] = cmd.pad_top
        explicit_padding[2] = cmd.pad_bottom

    return (explicit_padding[1], explicit_padding[0])


def generate_register_command_stream(nng, sg, arch, verbose=False):
    emit = CommandStreamEmitter()

    base_ptr_idx_map = {
        MemArea.Sram: BasePointerIndex.Scratch,
        MemArea.OnChipFlash: BasePointerIndex.ReadOnly,
        MemArea.OffChipFlash: BasePointerIndex.ReadOnly,
        MemArea.Dram: BasePointerIndex.ReadOnly,
    }

    # Maps an AccumulatorType enum to the corresponding acc_format value
    acc_format_map = {
        SHRAMElements.Acc16: acc_format.FP_S5_10.value,
        SHRAMElements.Acc32: acc_format.INT_32BIT.value,
        SHRAMElements.Acc40: acc_format.INT_40BIT.value,
    }

    # Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
    elementwise_mode_map = {
        "MulAct": elementwise_mode.MUL.value,
        "AddAct": elementwise_mode.ADD.value,
        "SubAct": elementwise_mode.SUB.value,
        "Minimum": elementwise_mode.MIN.value,
        "Maximum": elementwise_mode.MAX.value,
        "LeakyRelu": elementwise_mode.LRELU.value,
        "Abs": elementwise_mode.ABS.value,
    }

    cmd_stream = []
    for cmd in sg.high_level_command_stream:
        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
            print("Warning: Skipping register command stream generation for", cmd.ps)
        else:
            cmd_stream.append(cmd)

    dependencies = calc_command_dependencies(cmd_stream, arch)

    # Initialise operator dependency state
    prev_ifm_rect = cur_ifm_rect = None
    prev_ifm_block_depth = cur_ifm_block_depth = None
    prev_ofm_rect = cur_ofm_rect = None
    prev_ofm_block = cur_ofm_block = None
    prev_kernel = cur_kernel = None
    prev_cmd = None

    def emit_wait_commands(cmd):
        # The command is fully set up, emit whatever wait commands we need
        absolute_dep, relative_dep = dependencies[cmd]
        if relative_dep[CommandType.NpuStripe] is not None:
            if cmd.cmdtype == CommandType.DMA:
                param = relative_dep[CommandType.NpuStripe][1]
                if param <= 3:
                    emit.cmd_wait(cmd0.NPU_OP_KERNEL_WAIT, param, absolute_dep[CommandType.NpuStripe][1])
            else:
                param = relative_dep[CommandType.NpuStripe][0]
                param = min(param, 0xFFFF)  # Clamp to allowable wait amount

        if relative_dep[CommandType.DMA] is not None:
            param = relative_dep[CommandType.DMA][0]
            param = min(param, 0xF)  # Clamp to allowable wait amount
            emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, param, absolute_dep[CommandType.DMA][0])

    for cmd in cmd_stream:
        if cmd.cmdtype == CommandType.DMA:
            start_coord = cmd.box.start_coord

            src_addr = cmd.in_tensor.address_for_coordinate(start_coord)
            dst_addr = cmd.out_tensor.address_for_coordinate(start_coord)

            if cmd.in_tensor.compressed_values is not None:
                stream_index = cmd.in_tensor.compressed_stream_index_from_coord(start_coord)
                sz = cmd.in_tensor.size_of_compressed_stream(stream_index)
            else:
                sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr

            # TODO: Yoda support needs to use feature_maps_not_in_fast_storage and force_outputs_to_fast_storage
            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_SRC_REGION, base_ptr_idx_map[cmd.in_tensor.mem_area])
            emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_SRC, src_addr)
            emit.cmd0_with_param(cmd0.NPU_SET_DMA0_DST_REGION, base_ptr_idx_map[cmd.out_tensor.mem_area])
            emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_DST, dst_addr)
            emit.cmd1_with_offset(cmd1.NPU_SET_DMA0_LEN, sz)
            dma_channel = 0
            mode = 0  # From external to external

            emit_wait_commands(cmd)
            emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, dma_channel * 16 + mode)

        elif cmd.cmdtype == CommandType.NpuStripe:

            ps = cmd.ps
            primary_op = ps.primary_op
            npu_block_type = ps.npu_block_type
            # Specifies if global scale from the NPU_SET_OFM_SCALE register should be used instead of per-channel scale
            use_global_scale = False
            # Specifies type of rounding to be used.
            rounding_mode = rounding.TFL
            fmf = primary_op.attrs.get("fused_memory_function", None)
            faf = primary_op.attrs.get("fused_activation_function", None)

            # Specifies which operand to apply scaling to in bitexact elementwise ADD/SUB
            op_to_scale = 0

            # Update state history
            prev_ifm_rect = cur_ifm_rect
            prev_ifm_block_depth = cur_ifm_block_depth
            prev_ofm_rect = cur_ofm_rect
            prev_ofm_block = cur_ofm_block
            prev_kernel = cur_kernel

            block_config = ps.block_config
            emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, block_config[0] - 1)
            emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_WIDTH_M1, block_config[1] - 1)
            emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_DEPTH_M1, block_config[3] - 1)

            shared_buffer = ps.shared_buffer

            if npu_block_type == NpuBlockType.ElementWise:
                ifm2_broadcast = 0

                if cmd.ifm_tensor.shape == []:
                    # The scalar has to be the ifm2 tensor so switch the ifms
                    cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
                    cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box

                    # Set ReverseOperandOrder bit to IFM2_BROADCAST
                    ifm2_broadcast |= IFM2Broadcast.ReverseOperandOrder

                # Calculate scales needed for arithmetic elementwise operators
                if primary_op.type in set(("AddAct", "MulAct", "SubAct",)):
                    input_scale = cmd.ifm_tensor.quantization.scale_f32
                    input2_scale = cmd.ifm2_tensor.quantization.scale_f32
                    output_scale = cmd.ofm_tensor.quantization.scale_f32
                    use_global_scale = True

                    if primary_op.type == "MulAct":
                        if (faf == "Sigmoid") or (faf == "Tanh"):
                            output_scale = 1 / 0x3000

                        ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
                        emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
                    else:  # AddAct/SubAct
                        if (faf == "Sigmoid") or (faf == "Tanh"):
                            output_scale = 1 / 0x3000

                        if input_scale == input2_scale:
                            opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
                                input_scale, input2_scale, output_scale
                            )
                            opa_shift = 0  # Unused for this case
                        else:
                            # Use advanced implementation only when input scales differ
                            bitdepth = cmd.ifm_tensor.dtype.bits
                            (
                                opa_scale,
                                opa_shift,
                                ofm_scale,
                                shift,
                                op_to_scale,
                            ) = scaling.advanced_elementwise_add_sub_scale(
                                input_scale, input2_scale, output_scale, bitdepth
                            )
                            opb_scale = 0  # Unused for this case
                            if ifm2_broadcast & IFM2Broadcast.ReverseOperandOrder:
                                # If the operand order is reversed we also have to swap which operand is scaled
                                if op_to_scale == scaling.OperandToScale.OPa:
                                    op_to_scale = scaling.OperandToScale.OPb
                                else:
                                    op_to_scale = scaling.OperandToScale.OPa

                        emit.cmd1_with_offset(cmd1.NPU_SET_OPA_SCALE, opa_scale, opa_shift)
                        emit.cmd1_with_offset(cmd1.NPU_SET_OPB_SCALE, opb_scale)
                        emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)

                if primary_op.type in set(("LeakyRelu", "Abs",)):
                    output_scale = cmd.ofm_tensor.quantization.scale_f32
                    use_global_scale = True

                    if primary_op.type == "LeakyRelu":
                        output_scale *= primary_op.attrs["alpha"]

                    ofm_scale, shift = scaling.quantise_scale(output_scale)
                    emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)

                # For elementwise set the required SHRAM to be equal to the total size of SHRAM
                shram_required = arch.shram_total_banks
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required)

                # Acc buffers not needed so set AB_START to size of SHRAM
                emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch.shram_total_banks)

                # Is not a unary operator
                if cmd.ifm2_tensor is not None:
                    if cmd.ifm2_tensor.shape == []:
                        # IFM2 is a constant, set UseIFM2Scalar bit to IFM2_BROADCAST
                        ifm2_broadcast |= IFM2Broadcast.UseIFM2Scalar
                    else:
                        ifm_box_shape = cmd.ifm_box.get_size_shape()
                        ifm2_box_shape = cmd.ifm2_box.get_size_shape()

                        if len(cmd.ifm_tensor.shape) > 1 and ifm_box_shape[1] != ifm2_box_shape[1]:
                            # Broadcast in 'H' dimension
                            assert cmd.ifm2_tensor.shape[1] == 1
                            ifm2_broadcast |= IFM2Broadcast.BroadcastHdim

                        if len(cmd.ifm_tensor.shape) > 2 and ifm_box_shape[2] != ifm2_box_shape[2]:
                            # Broadcast in 'W' dimension
                            assert cmd.ifm2_tensor.shape[2] == 1
                            ifm2_broadcast |= IFM2Broadcast.BroadcastWdim

                        if len(cmd.ifm_tensor.shape) > 3 and ifm_box_shape[3] != ifm2_box_shape[3]:
                            # Broadcast in 'C' dimension
                            assert cmd.ifm2_tensor.shape[3] == 1
                            ifm2_broadcast |= IFM2Broadcast.BroadcastCdim

                    # Set IFM2_IB_START to the latter half of the IB space
                    ifm_ib_start = shared_buffer.bank_locations[SharedBufferArea.IFM]
                    emit.cmd0_with_param(
                        cmd0.NPU_SET_IFM2_IB_START, (shram_required - ifm_ib_start) / 2 + ifm_ib_start
                    )

                    emit.cmd0_with_param(cmd0.NPU_SET_IFM2_BROADCAST, ifm2_broadcast)

            else:
                emit.cmd0_with_param(
                    cmd0.NPU_SET_IFM_IB_END,
                    shared_buffer.bank_locations[SharedBufferArea.IFM]
                    + shared_buffer.banks_required[SharedBufferArea.IFM],
                )
                emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shared_buffer.bank_locations[SharedBufferArea.Accumulators])

            emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element])

            emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, 0)

            if npu_block_type in set(
                (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling)
            ):
                # Set up padding
                explicit_padding = list(primary_op.attrs["explicit_padding"])  # (top, left, bottom, right)

                # Check if this is for horizontal ifm streaming
                if not (cmd.is_first_h_stripe and cmd.is_last_h_stripe):
                    explicit_padding[0] = cmd.pad_top
                    explicit_padding[2] = cmd.pad_bottom

                # Indexing from the end, since a 1x1 AvgPool might have been added with a non 4-dimensional
                # input/output because an activation function needed to be fused.
                if cmd.ifm_box.start_coord[-2] > 0:
                    explicit_padding[1] = 0
                if cmd.ifm_box.end_coord[-2] < cmd.ifm_tensor.shape[-2]:
                    explicit_padding[3] = 0

                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, explicit_padding[0])
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_LEFT, explicit_padding[1])
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_BOTTOM, explicit_padding[2])
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_RIGHT, explicit_padding[3])

                # set kernel x stride low bit
                stride = primary_op.attrs["strides"][2] - 1 & 1
                # set kernel y stride low bit
                stride |= (primary_op.attrs["strides"][1] - 1 & 1) << 1
                # set kernel x stride extension bits
                stride |= (primary_op.attrs["strides"][2] - 1 >> 1) << 6
                # set kernel y stride extension bits
                stride |= (primary_op.attrs["strides"][1] - 1 >> 1) << 9

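                # Worked example (illustrative): for a strides attribute of (1, 2, 2, 1),
                # i.e. stride_y = stride_x = 2, the two low bits give stride = 0b11 = 3 and
                # the extension bits stay zero; an x stride of 3 would instead set bit 6
                # via ((3 - 1) >> 1) << 6.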

                if npu_block_type == NpuBlockType.Pooling:
                    k_height, k_width = primary_op.attrs["ksize"][1:3]
                    emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, k_height - 1)
                    emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, k_width - 1)

                    valid_padding = sum(explicit_padding) == 0

                    if primary_op.type in set(("AvgPool", "AvgPoolAct")) and valid_padding:
                        # For valid padding vela has to output scaling values
                        if faf == "Sigmoid" or faf == "Tanh":
                            rescale = 0x3000 * cmd.ifm_tensor.quantization.scale_f32
                            rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1

                            scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
                            scale = int(round_away_zero(scale * rescale))
                        else:
                            # In case the avg pool is fused with a concat or other memory operation, rescaling
                            # might be needed. k_height == k_width == 1 is always true in this case.
                            # Normally the scale is maximised to get maximum precision, which means that if
                            # rescale != 1, the scale needs to account for the number of bits needed for rescaling.
                            rescale = cmd.ifm_tensor.quantization.scale_f32 / cmd.ofm_tensor.quantization.scale_f32
                            rescale_bits = 0
                            if k_height == k_width == 1:
                                if fmf == "ConcatSliceWrite":
                                    rounding_mode = rounding.NATURAL
                                if rescale > 1:
                                    rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
                                elif rescale < 1:
                                    rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1)
                            scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
                            scale = int(round_away_zero(scale * rescale))

                        emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, scale, shift)
                        # Valid-padded average pool should use the global scale from
                        # NPU_SET_OFM_SCALE register, which is set above.
                        use_global_scale = True

                else:  # Convolution
                    assert cmd.weight_tensor.block_traversal != TensorBlockTraversal.Default
                    # Reduced precision quantization and natural rounding used for int16
                    if cmd.ifm_tensor.dtype == DataType.int16:
                        rounding_mode = rounding.NATURAL
                    emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, cmd.weight_tensor.shape[0] - 1)
                    emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, cmd.weight_tensor.shape[1] - 1)
                    if cmd.weight_tensor.block_traversal == TensorBlockTraversal.PartKernelFirst:
                        # Part-kernel-first weight ordering
                        assert npu_block_type == NpuBlockType.ConvolutionMxN
                        stride |= 1 << 2

                emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_STRIDE, stride)

            elif npu_block_type in set((NpuBlockType.VectorProduct,)):
                # Vector product is implemented using a 1x1 convolution, so we need
                # to set up the appropriate padding and kernel info
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, 0)
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_LEFT, 0)
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_BOTTOM, 0)
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_RIGHT, 0)

                # kernel stride reg = 0 means stride(1,1) + depth first weight
                # order + dilation(0,0) + kernel_split_size=8
                emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_STRIDE, 0)

                emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_HEIGHT_M1, 0)
                emit.cmd0_with_param(cmd0.NPU_SET_KERNEL_WIDTH_M1, 0)

            if npu_block_type in set(
                (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.VectorProduct)
            ):
                # Emit Weight base address commands, only maps the area required for
                # this command's weights from the larger tensor.
                stream_index = cmd.weight_tensor.compressed_stream_index_from_coord(cmd.weight_box.start_coord)
                weight_addr = cmd.weight_tensor.address_for_coordinate(cmd.weight_box.start_coord)
                weight_len = cmd.weight_tensor.size_of_compressed_stream(stream_index)
                # Select weight/scale region depending on where permanent storage was defined
                weight_region = base_ptr_idx_map[cmd.weight_tensor.mem_area]
                if arch.permanent_storage_mem_area == MemArea.Sram:
                    weight_region = BasePointerIndex.ReadOnly
                emit.cmd0_with_param(cmd0.NPU_SET_WEIGHT_REGION, weight_region)
                emit.cmd1_with_offset(cmd1.NPU_SET_WEIGHT_BASE, weight_addr)
                emit.cmd1_with_offset(cmd1.NPU_SET_WEIGHT_LENGTH, weight_len)

                # Emit Scale & Bias base address commands, with length matching the amount required by
                # the weight tensors.
                if cmd.scale_tensor is not None:
                    # Get address and size of the scale/bias data area
                    scale_addr = cmd.scale_tensor.address_for_coordinate(cmd.weight_box.start_coord[-1:])
                    scale_len = (
                        cmd.scale_tensor.address_for_coordinate(cmd.weight_box.end_coord[-1:], True) - scale_addr
                    )
                    # Emit base address for NPU to access scale & bias data
                    scale_region = base_ptr_idx_map[cmd.scale_tensor.mem_area]
                    if arch.permanent_storage_mem_area == MemArea.Sram:
                        scale_region = BasePointerIndex.ReadOnly
                    emit.cmd0_with_param(cmd0.NPU_SET_SCALE_REGION, scale_region)
                    emit.cmd1_with_offset(cmd1.NPU_SET_SCALE_BASE, scale_addr)
                    emit.cmd1_with_offset(cmd1.NPU_SET_SCALE_LENGTH, round_up(scale_len, 16))

            ofm_quant = cmd.ofm_tensor.quantization
            ofm_quant_qmin = cmd.ofm_tensor.quantization.quant_min
            ofm_quant_qmax = cmd.ofm_tensor.quantization.quant_max
            ifm_min = cmd.ifm_tensor.quantization.min
            ifm_max = cmd.ifm_tensor.quantization.max

            # Emit commands for any fused activation function
            if faf is None:
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                # Even if no activation function, values need to be set to override previous values
                faf_min = ofm_quant_qmin
                faf_max = ofm_quant_qmax
            elif faf == "Relu":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = ofm_quant_qmax
            elif faf == "Relu6":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = quantise_float32(6.0, ofm_quant.scale_f32, ofm_quant.zero_point)
            elif faf == "ReluN1To1":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
                faf_min = quantise_float32(-1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
            elif faf == "Tanh":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.TANH)
                faf_min = quantise_float32(clamp_tanh(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = quantise_float32(clamp_tanh(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
            elif faf == "Sigmoid":
                emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.SIGMOID)
                faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
            else:
                raise Exception("Unsupported fused_activation_function = " + faf)

            # Activation range needs to be set based upon the quantisation range and the fused activation range
            emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MIN, max(ofm_quant_qmin, faf_min))
            emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MAX, min(ofm_quant_qmax, faf_max))

            out_shape = cmd.ofm_box.get_size_shape()
            if len(out_shape) >= 4:
                emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT_M1, out_shape[-3] - 1)
            else:
                emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT_M1, 0)
            if len(out_shape) >= 2:
                emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, out_shape[-2] - 1)
            else:
                emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, 0)
            emit.cmd0_with_param(cmd0.NPU_SET_OFM_DEPTH_M1, out_shape[-1] - 1)

            if npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
                in_shape = cmd.ifm_box.get_size_shape()
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, in_shape[-1] - 1)
            else:
                emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, out_shape[-1] - 1)

            for tens, box, region_op, ptr_ops, stride_ops, zero_point_op in (
                (
                    cmd.ifm_tensor,
                    cmd.ifm_box,
                    cmd0.NPU_SET_IFM_REGION,
                    (cmd1.NPU_SET_IFM_BASE0, cmd1.NPU_SET_IFM_BASE1, cmd1.NPU_SET_IFM_BASE2, cmd1.NPU_SET_IFM_BASE3),
                    (cmd1.NPU_SET_IFM_STRIDE_C, cmd1.NPU_SET_IFM_STRIDE_Y, cmd1.NPU_SET_IFM_STRIDE_X),
                    cmd0.NPU_SET_IFM_ZERO_POINT,
                ),
                (
                    cmd.ifm2_tensor,
                    cmd.ifm2_box,
                    cmd0.NPU_SET_IFM2_REGION,
                    (
                        cmd1.NPU_SET_IFM2_BASE0,
                        cmd1.NPU_SET_IFM2_BASE1,
                        cmd1.NPU_SET_IFM2_BASE2,
                        cmd1.NPU_SET_IFM2_BASE3,
                    ),
                    (cmd1.NPU_SET_IFM2_STRIDE_C, cmd1.NPU_SET_IFM2_STRIDE_Y, cmd1.NPU_SET_IFM2_STRIDE_X),
                    cmd0.NPU_SET_IFM2_ZERO_POINT,
                ),
                (
                    cmd.ofm_tensor,
                    cmd.ofm_box,
                    cmd0.NPU_SET_OFM_REGION,
                    (cmd1.NPU_SET_OFM_BASE0, cmd1.NPU_SET_OFM_BASE1, cmd1.NPU_SET_OFM_BASE2, cmd1.NPU_SET_OFM_BASE3),
                    (cmd1.NPU_SET_OFM_STRIDE_C, cmd1.NPU_SET_OFM_STRIDE_Y, cmd1.NPU_SET_OFM_STRIDE_X),
                    cmd0.NPU_SET_OFM_ZERO_POINT,
                ),
            ):

                if tens is None:
                    continue

                need_zero_point = (faf is not None) or (fmf == "ConcatSliceWrite")
                if (
                    primary_op.type in set(("AvgPool", "AvgPoolAct")) and not need_zero_point
                ) or tens.quantization is None:
                    # Actual integer operation, just set scale to 1 and zero point to 0
                    emit.cmd0_with_param(zero_point_op, 0)
                else:
                    assert tens.quantization.zero_point is not None, "need an actual zero point set"
                    emit.cmd0_with_param(zero_point_op, int(tens.quantization.zero_point))

                if tens.shape == []:
                    # Empty shape, elementwise constant
                    ifm2_scalar = tens.quant_values.astype(np.uint8)
                    assert ifm2_scalar.size == 1
                    emit.cmd0_with_param(cmd0.NPU_SET_IFM2_SCALAR, ifm2_scalar.item(0))
                    continue

                height_0, height_1, width_0, addresses = tens.addresses_for_rolling_buffer(
                    box.start_coord, box.end_coord
                )
                if npu_block_type != NpuBlockType.VectorProduct:
                    if tens == cmd.ifm_tensor:
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM_HEIGHT0_M1, height_0 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM_HEIGHT1_M1, height_1 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM_WIDTH0_M1, width_0 - 1)
                    elif tens == cmd.ofm_tensor:
                        emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT0_M1, height_0 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_OFM_HEIGHT1_M1, height_1 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH0_M1, width_0 - 1)
                    elif tens == cmd.ifm2_tensor:
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM2_HEIGHT0_M1, height_0 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM2_HEIGHT1_M1, height_1 - 1)
                        emit.cmd0_with_param(cmd0.NPU_SET_IFM2_WIDTH0_M1, width_0 - 1)
                else:
                    if len(out_shape) == 2:
                        # TODO: N is put in W-dimension for now
                        # Should be spread over H and W, but then block size selection
                        # and stride calculation would need to be changed
                        if tens == cmd.ifm_tensor:
                            emit.cmd0_with_param(cmd0.NPU_SET_IFM_WIDTH0_M1, out_shape[-2] - 1)
                        elif tens == cmd.ofm_tensor:
                            emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH0_M1, out_shape[-2] - 1)
                    else:
                        assert False

                if tens.mem_area == MemArea.Sram:
                    emit.cmd0_with_param(region_op, BasePointerIndex.Scratch)
                else:
                    emit.cmd0_with_param(region_op, BasePointerIndex.ReadOnly)

                for idx, addr in enumerate(addresses):
                    if addr is None:
                        addresses[idx] = 0

                emit.cmd1_with_offset(ptr_ops[0], addresses[0])
                emit.cmd1_with_offset(ptr_ops[1], addresses[1])
                emit.cmd1_with_offset(ptr_ops[2], addresses[2])
                emit.cmd1_with_offset(ptr_ops[3], addresses[3])

                strides = tens.get_strides()
                emit.cmd1_with_offset(stride_ops[0], strides[1])  # stride between 16-byte channel blocks (C)
                emit.cmd1_with_offset(stride_ops[2], strides[3])  # stride between horizontal values (W)
                emit.cmd1_with_offset(stride_ops[1], strides[2])  # stride between vertical values (H)

                if tens.format == TensorFormat.NHCWB16:
                    # Check that all BasePointer addresses are aligned to 16 bytes
                    assert (int(addresses[0]) % 16) == 0
                    assert (int(addresses[1]) % 16) == 0
                    assert (int(addresses[2]) % 16) == 0
                    assert (int(addresses[3]) % 16) == 0

            ofm_dtype = cmd.ofm_tensor.dtype
            assert ofm_dtype.type & BaseType.Int
            prec = 0
            if ofm_dtype.size_in_bits() == 8:
                prec = 0
            elif ofm_dtype.size_in_bits() == 16:
                prec = 2
            else:
                assert 0

            if ofm_dtype.type & BaseType.Signed:
                prec += 1

            if use_global_scale:
                # Set global scale bit, as opposed to using per channel scale
                prec |= 1 << 8

            if cmd.ofm_tensor.format == TensorFormat.NHCWB16:
                prec |= 1 << 6

            prec |= rounding_mode.value << 14
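            # Worked example (illustrative): an int8 OFM in NHCWB16 format with a global
            # scale and TFL rounding (assumed to encode as 0) packs to
            # prec = 1 | (1 << 6) | (1 << 8).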

            emit.cmd0_with_param(cmd0.NPU_SET_OFM_PRECISION, prec)

            prec = None
            weight_bits = 8
            if cmd.weight_tensor is not None:
                weight_bits = cmd.weight_tensor.dtype.size_in_bits()

            ifm_dtype = cmd.ifm_tensor.dtype

            assert weight_bits == 8, "Unsupported weight bit depth"
            assert ifm_dtype.size_in_bits() in {8, 16}

            if ifm_dtype.size_in_bits() == 8:
                if ifm_dtype.type & BaseType.Signed:
                    prec = ifm_precision.S8
                else:
                    prec = ifm_precision.U8
            elif ifm_dtype.size_in_bits() == 16:
                if ifm_dtype.type & BaseType.Signed:
                    prec = ifm_precision.S16
                else:
                    prec = ifm_precision.U16

            ifm_prec = prec.value
            ifm2_prec = ifm_prec

            if cmd.ifm_tensor.format == TensorFormat.NHCWB16:
                ifm_prec |= 1 << 6

            ifm_prec |= op_to_scale << 8

            emit.cmd0_with_param(cmd0.NPU_SET_IFM_PRECISION, ifm_prec)

            if cmd.ifm2_tensor is not None:
                if cmd.ifm2_tensor.format == TensorFormat.NHCWB16:
                    ifm2_prec |= 1 << 6
                emit.cmd0_with_param(cmd0.NPU_SET_IFM2_PRECISION, ifm2_prec)

            emit_wait_commands(cmd)

            # Get op parameters
            cur_ifm_block_depth = get_op_ifmofm_block_depth(arch, cmd)
            cur_ofm_block = Block(ps.block_config[1], ps.block_config[0], ps.block_config[3])
            cur_ofm_rect = get_op_ofm_rect(cmd)
            cur_ifm_rect = get_op_ifm_rect(cmd)
            cur_kernel = get_op_kernel(cmd.ps)
            cur_padLT = get_op_padding_lt(cmd)
            if (prev_kernel is not None) and (cur_kernel is not None) and has_prev_op_dependency(prev_cmd, cmd):
                if cmd.ifm_tensor.shape == prev_cmd.ofm_tensor.shape:
                    blockdep = arch.calc_block_dep(
                        prev_ifm_rect,
                        prev_ofm_rect,
                        prev_ifm_block_depth,
                        prev_ofm_block,
                        prev_kernel,
                        cur_ifm_rect,
                        cur_ofm_rect,
                        cur_ifm_block_depth,
                        cur_ofm_block,
                        cur_kernel,
                        cur_padLT,
                    )
                else:
                    blockdep = 0
            else:
                blockdep = ArchitectureFeatures.MAX_BLOCKDEP
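            # Note (illustrative): a larger block dependency value lets the hardware overlap
            # this op with the previous one more aggressively, while 0 makes it wait for the
            # previous op's output; MAX_BLOCKDEP is used when there is no producer/consumer
            # relationship between the two ops.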

            # Set between every op (dependent or not)
            blockdep = min(blockdep, arch.max_blockdep)
            emit.cmd0_with_param(cmd0.NPU_SET_BLOCKDEP, blockdep)
            prev_cmd = cmd

            if npu_block_type == NpuBlockType.ConvolutionMxN:
                emit.cmd_do_operation(cmd0.NPU_OP_CONV)
            elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
                emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE)
            elif npu_block_type == NpuBlockType.VectorProduct:
                # Vector product is implemented using a 1x1 convolution
                emit.cmd_do_operation(cmd0.NPU_OP_CONV)
            elif npu_block_type == NpuBlockType.Pooling:
                param = "Max" not in primary_op.type
                emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=param)
            elif npu_block_type == NpuBlockType.ElementWise:
                param = elementwise_mode_map[primary_op.type]
                emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param)
            else:
                print("Warning: Skipping register command stream generation for", ps)

    # Fill in final part of command stream:
    emit.cmd_do_operation(cmd0.NPU_OP_STOP, param=0xFFFF)

    sg.register_command_stream = emit.to_list()
    if verbose:
        emit.print_cmds()
        print("number of commands", len(emit.cmd_stream))
        print("command stream length in words", len(sg.register_command_stream))