MLBEDSW-3653: Fix type errors in annotated files

This commit corrects a number of type errors
reported by mypy and refactors parts of the
code that became redundant after the
adjustments made to satisfy mypy.
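
The bulk of the changes replace comparisons on the op_type
enum attribute with isinstance() checks, which mypy can use
to narrow the operation type before subclass-specific
attributes are accessed. A minimal, self-contained sketch of
the pattern (Base, Conv2D and dispatch are illustrative
names, not the vela API):

    class Base:
        pass

    class Conv2D(Base):
        block_traversal: str = "DEPTH_FIRST"

    def dispatch(op: Base) -> str:
        # isinstance() narrows op to Conv2D here, so the
        # attribute access below type-checks; comparing an
        # enum tag would leave op typed as Base and mypy
        # would reject the access.
        if isinstance(op, Conv2D):
            return op.block_traversal
        return "NONE"

Relatedly, the empty dict for memory accesses gets an
explicit Dict[NpuOperation, MemoryAccessSet] annotation,
since mypy cannot infer key and value types from a bare {}
literal.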

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I16b880b228e57f2a92fb8936f53e94886e0f9f44
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index d4947b1..fa56d35 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -20,6 +20,7 @@
 from collections import defaultdict
 from enum import Enum
 from enum import IntEnum
+from typing import Dict
 from typing import List
 from typing import Optional
 
@@ -33,6 +34,7 @@
 from .api import NpuBlockOperation
 from .api import NpuBlockTraversal
 from .api import NpuConv2DOperation
+from .api import NpuConvDepthWiseOperation
 from .api import NpuDataType
 from .api import NpuDmaOperation
 from .api import NpuElementWiseOp
@@ -68,13 +70,13 @@
 from .numeric_util import round_away_zero
 from .numeric_util import round_up_to_int
 from .operation import NpuBlockType
+from .range_set import MemoryAccessSet
 from .register_command_stream_util import calc_blockdep
 from .register_command_stream_util import get_dma_memory_accesses
 from .register_command_stream_util import get_op_memory_accesses
 from .register_command_stream_util import get_strides
 from .register_command_stream_util import get_wait_dependency
 from .register_command_stream_util import has_ifm2
-from .register_command_stream_util import is_dma_op
 from .register_command_stream_util import to_kernel
 from .register_command_stream_util import UNARY_ELEMWISE_OPS
 from .register_command_stream_util import Watermark
@@ -549,15 +551,13 @@
 
 def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> SharedBufferAllocation:
     """Creates shared buffer allocation for the given operation"""
-    op_type = npu_op.op_type
-    block_type = NpuBlockType.Default
-    if op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         block_type = NpuBlockType.ConvolutionMxN
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         block_type = NpuBlockType.ConvolutionDepthWise
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         block_type = NpuBlockType.ElementWise
     else:
         assert 0, "Unsupported operation"
@@ -599,7 +599,7 @@
     generate_activation(emit, npu_op.activation, npu_op.ofm)
     shared_buffer = create_shared_buffer(npu_op, arch)
     generate_block_config(emit, npu_op, arch, shared_buffer)
-    if npu_op.op_type == NpuOperationType.ElementWise:
+    if isinstance(npu_op, NpuElementWiseOperation):
         generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
     else:
         generate_shram_registers_non_elementwise(emit, shared_buffer)
@@ -746,17 +746,20 @@
         print(f"         {stride_str}, tiles: w0={t.width_0}, h0={t.height_0}, h1={t.height_1}, base={addresses}")
 
 
-def print_operation(npu_op: NpuOperation, index: int = 0):
-    pass_info = f", {npu_op.cmd}" if hasattr(npu_op, "cmd") else ""
-    if is_dma_op(npu_op):
+def print_operation(npu_op: NpuOperation, index: int = 0, cmd=None):
+    pass_info = f", {cmd}" if cmd else ""
+    if not isinstance(npu_op, (NpuDmaOperation, NpuBlockOperation)):
+        print(f"{index} {npu_op.op_type.name}{pass_info}")
+        return
+    if isinstance(npu_op, NpuDmaOperation):
         print(f"{index} DMA_START src={npu_op.src}, dest={npu_op.dest}{pass_info}")
         return
     k = None if npu_op.kernel is None else to_kernel(npu_op.kernel)
-    if npu_op.op_type in (NpuOperationType.Pooling, NpuOperationType.ElementWise):
+    if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)):
         print(f"{index} {npu_op.sub_op_type.name} {npu_op.op_type.name}:{pass_info}")
     else:
         if (
-            npu_op.op_type == NpuOperationType.Conv2D
+            isinstance(npu_op, NpuConv2DOperation)
             and k.elements_wh() * k.stride.x * k.stride.y * k.dilation.x * k.dilation.y == 1
         ):
             fc = "FullyConnected "
@@ -783,16 +786,19 @@
         if act.op_type != NpuActivationOp.NONE_OR_RELU or act.min is not None or act.max is not None:
             lut = f", lut index={act.lookup_table_index}" if act.op_type == NpuActivationOp.TABLE_LOOKUP else ""
             print(f"      Activation: {act.op_type.name}, min={act.min}, max={act.max}{lut}")
-    if npu_op.op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         print(f"      {npu_op.block_traversal}")
     bh, bw, bc = npu_op.block_config
-    rescale = f", rescale={npu_op.rescale}" if hasattr(npu_op, "rescale") else ""
+    rescale = (
+        f", rescale={npu_op.rescale}" if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)) else ""
+    )
     print(f"      Block config: h={bh},w={bw},c={bc}, {npu_op.ifm_upscale}, {npu_op.rounding_mode}{rescale}")
 
 
-def print_operations(npu_op_list: List[NpuOperation]):
+def print_operations(npu_op_list: List[NpuOperation], npu_op_to_cmd=None):
+    npu_op_to_cmd = dict() if npu_op_to_cmd is None else npu_op_to_cmd
     for index, npu_op in enumerate(npu_op_list):
-        print_operation(npu_op, index)
+        print_operation(npu_op, index, npu_op_to_cmd.get(npu_op))
 
 
 # -------------------------------------------------------------------
@@ -802,16 +808,15 @@
 
 def generate_operation_code(emit: CommandStreamEmitter, npu_op: NpuOperation):
     """Generates NPU_OP_* command"""
-    op_type = npu_op.op_type
-    if op_type == NpuOperationType.Dma:
+    if isinstance(npu_op, NpuDmaOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, npu_op.channel * 16 + npu_op.mode)
-    elif op_type == NpuOperationType.Conv2D:
+    elif isinstance(npu_op, NpuConv2DOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_CONV)
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE)
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_op_map[npu_op.sub_op_type])
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param=elementwise_op_map[npu_op.sub_op_type])
     else:
         assert 0, "Unsupported operation"
@@ -822,7 +827,9 @@
     generate_common(emit, npu_op, npu_op.block_traversal, arch)
 
 
-def generate_conv_depthwise_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures):
+def generate_conv_depthwise_op(
+    emit: CommandStreamEmitter, npu_op: NpuConvDepthWiseOperation, arch: ArchitectureFeatures
+):
     """Generates register commands for depthwise convolution operations"""
     generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch)
 
@@ -880,23 +887,22 @@
     Generates register commands for the given operation, but not the final NPU_OP_... command.
     Returns the selected block config
     """
-    op_type = npu_op.op_type
-    if op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         generate_conv2d_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         generate_conv_depthwise_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         generate_pooling_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         generate_elementwise_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.Dma:
+    elif isinstance(npu_op, NpuDmaOperation):
         generate_dma_op(emit, npu_op)
     else:
         assert 0, "Unsupported operation"
 
 
 def generate_command_stream(
-    npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None,
+    npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None, npu_op_to_cmd=None
 ) -> List[int]:
     """
     Generates register commands for the given list of NPU operations.
@@ -904,14 +910,16 @@
     """
     emit = CommandStreamEmitter()
     if verbose:
-        print_operations(npu_op_list)
+        print_operations(npu_op_list, npu_op_to_cmd)
     # Calculate memory accesses for every operation
-    memory_accesses = {}
+    memory_accesses: Dict[NpuOperation, MemoryAccessSet] = {}
     for npu_op in npu_op_list:
-        if is_dma_op(npu_op):
+        if isinstance(npu_op, NpuDmaOperation):
             memory_accesses[npu_op] = get_dma_memory_accesses(npu_op)
-        else:
+        elif isinstance(npu_op, NpuBlockOperation):
             memory_accesses[npu_op] = get_op_memory_accesses(npu_op, arch)
+        else:
+            assert 0, "Invalid operation type"
     if arch.is_ethos_u65_system:
         emit.cmd0_with_param(cmd0.NPU_SET_PARALLEL_MODE, arch.ncores - 1)
     dep_watermark = Watermark(0, 0)
@@ -920,7 +928,7 @@
     for op_index, npu_op in enumerate(npu_op_list):
         dep_watermark, cmd_waits = get_wait_dependency(arch, npu_op_list, memory_accesses, op_index, dep_watermark)
         generate_registers_for_op(emit, npu_op, arch)
-        if not is_dma_op(npu_op):
+        if isinstance(npu_op, NpuBlockOperation):
             # Generate BLOCKDEP
             blockdep = calc_blockdep(arch, prev_op, npu_op)
             blockdep = min(blockdep, arch.max_blockdep)
@@ -951,12 +959,12 @@
     """
     Internal implementation of the public facing API for finding block configs.
     """
-    if is_dma_op(npu_op):
-        return []
-    arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
-    shared_buffer = create_shared_buffer(npu_op, arch)
-    blocks = find_suitable_block_configs(arch, shared_buffer)
-    return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+    if isinstance(npu_op, NpuBlockOperation):
+        arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
+        shared_buffer = create_shared_buffer(npu_op, arch)
+        blocks = find_suitable_block_configs(arch, shared_buffer)
+        return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+    return []
 
 
 def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]: