MLBEDSW-3653: Fix type errors in annotated files

This commit corrects a number of type errors
reported by mypy and removes or refactors parts
of the code that are no longer necessary after
the adjustments made to satisfy mypy.
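
An illustrative sketch of the two recurring patterns is shown below
(the names Registry and describe are made up for this example and do
not appear in the patch): empty class-level containers are given
explicit annotations so mypy can infer their element types, and
comparisons against CommandType/NpuOperationType enum values are
replaced with isinstance() checks that mypy can use to narrow the
operation type in each branch.

    # Illustrative sketch only -- stand-in classes, not code from this patch.
    from typing import Any, Dict, List

    UntypedDict = Dict[Any, Any]  # aliases keep the annotations short
    UntypedList = List[Any]


    class Command:  # stand-in for the real Command base class
        pass


    class DMA(Command):  # stand-in for high_level_command_stream.DMA
        pass


    class NpuStripe(Command):  # stand-in for high_level_command_stream.NpuStripe
        pass


    class Registry:
        # Annotating class-level containers: without the annotation mypy
        # cannot infer an element type for an empty dict/list literal.
        _uid: UntypedDict = {}
        _table: UntypedList = []


    def describe(cmd: Command) -> str:
        # isinstance() lets mypy narrow the type of cmd in each branch,
        # which is what makes a separate CommandType enum (and the
        # cmdtype attribute) unnecessary.
        if isinstance(cmd, DMA):
            return "dma"
        if isinstance(cmd, NpuStripe):
            return "npu stripe"
        return "unknown"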

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I16b880b228e57f2a92fb8936f53e94886e0f9f44
diff --git a/ethosu/vela/__init__.py b/ethosu/vela/__init__.py
index 90376be..77c171d 100644
--- a/ethosu/vela/__init__.py
+++ b/ethosu/vela/__init__.py
@@ -14,6 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from ._version import __version__
-from .vela import main
 
-__all__ = [main, __version__]
+__all__ = ["main", __version__]
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py
index a4b7b53..3ad642a 100644
--- a/ethosu/vela/data_type.py
+++ b/ethosu/vela/data_type.py
@@ -16,6 +16,7 @@
 # Description:
 # Defines the basic numeric type classes for tensors.
 import enum
+from typing import Any
 
 from .numeric_util import round_up_divide
 
@@ -43,6 +44,34 @@
 
     __slots__ = "type", "bits"
 
+    int8: Any
+    int16: Any
+    int32: Any
+    int64: Any
+    uint8: Any
+    uint16: Any
+    uint32: Any
+    uint64: Any
+    quint4: Any
+    quint8: Any
+    quint12: Any
+    quint16: Any
+    quint32: Any
+    qint4: Any
+    qint8: Any
+    qint12: Any
+    qint16: Any
+    qint32: Any
+    float16: Any
+    float32: Any
+    float64: Any
+    string: Any
+    bool: Any
+    resource: Any
+    variant: Any
+    complex64: Any
+    complex128: Any
+
     def __init__(self, type_, bits):
         self.type = type_
         self.bits = bits
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
index b5852cd..4f0a50a 100644
--- a/ethosu/vela/debug_database.py
+++ b/ethosu/vela/debug_database.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 import csv
 import io
+from typing import Any
+from typing import Dict
+from typing import List
 
 import lxml.etree as xml
 
@@ -22,28 +25,32 @@
 from .operation import Operation
 
 
+UntypedDict = Dict[Any, Any]
+UntypedList = List[Any]
+
+
 class DebugDatabase:
     NULLREF = -1
     show_warnings = False
 
     SOURCE_TABLE = "source"
-    _sourceUID = {}
+    _sourceUID: UntypedDict = {}
     _sourceHeaders = ["id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
-    _sourceTable = []
+    _sourceTable: UntypedList = []
 
     OPTIMISED_TABLE = "optimised"
-    _optimisedUID = {}
+    _optimisedUID: UntypedDict = {}
     _optimisedHeaders = ["id", "source_id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
-    _optimisedTable = []
+    _optimisedTable: UntypedList = []
 
     QUEUE_TABLE = "queue"
     _queueHeaders = ["offset", "cmdstream_id", "optimised_id"]
-    _queueTable = []
+    _queueTable: UntypedList = []
 
     STREAM_TABLE = "cmdstream"
-    _streamUID = {}
+    _streamUID: UntypedDict = {}
     _streamHeaders = ["id", "file_offset"]
-    _streamTable = []
+    _streamTable: UntypedList = []
 
     @classmethod
     def add_source(cls, op: Operation):
diff --git a/ethosu/vela/driver_actions.py b/ethosu/vela/driver_actions.py
index 86bed11..5a85df0 100644
--- a/ethosu/vela/driver_actions.py
+++ b/ethosu/vela/driver_actions.py
@@ -117,7 +117,7 @@
     """Creates driver header and includes the given command
     """
     # Prepare driver actions for this command tensor
-    da_list = []
+    da_list: List[int] = []
     emit_fourcc(da_list, "COP1")
     emit_config(da_list, 0, 1, arch)
     emit_cmd_stream_header(da_list, len(register_command_stream))
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index d057d17..c45bc4e 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 # Description:
 # Contains classes that hold commands for the high-level command stream (one command per DMA or NPU stripe).
-from enum import IntEnum
-
 import numpy as np
 
 from .architecture_features import Block
@@ -144,12 +142,6 @@
     __repr__ = __str__
 
 
-class CommandType(IntEnum):
-    NpuStripe = 0
-    DMA = 1
-    Size = 2
-
-
 class Command:
     def get_ofm_y_range_for_pass(self, ps_requested):
         return None
@@ -158,7 +150,7 @@
         return False
 
     def get_operation_count(self):
-        # returns numpy array of (DPU blocks, dma_ops). Should line up with the CommandType enum
+        # returns numpy array of (DPU blocks, dma_ops).
         return np.array((0, 0))
 
 
@@ -185,7 +177,6 @@
         pad_top=0,
         pad_bottom=0,
     ):
-        self.cmdtype = CommandType.NpuStripe
         self.ps = ps
         self.block_config = block_config
         self.is_first = is_first
@@ -333,7 +324,6 @@
 
 class DMA(Command):
     def __init__(self, ps, in_tensor, out_tensor, box):
-        self.cmdtype = CommandType.DMA
         self.ps = ps
         self.in_tensor = in_tensor
         self.out_tensor = out_tensor
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 7db4931..9e0ed01 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -48,7 +48,6 @@
 from .debug_database import DebugDatabase
 from .high_level_command_stream import Box
 from .high_level_command_stream import Command
-from .high_level_command_stream import CommandType
 from .high_level_command_stream import DMA
 from .high_level_command_stream import NpuStripe
 from .operation import NpuBlockType
@@ -56,7 +55,6 @@
 from .operation import Operation
 from .register_command_stream_generator import generate_command_stream
 from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
-from .register_command_stream_util import is_dma_op
 from .register_command_stream_util import to_npu_kernel
 from .register_command_stream_util import UNARY_ELEMWISE_OPS
 from .tensor import MemType
@@ -168,7 +166,7 @@
     else:
         base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor
 
-    return int(base_ptr_idx_map[tens.mem_type])
+    return base_ptr_idx_map[tens.mem_type].value
 
 
 def get_upscale(op: Operation) -> NpuResamplingMode:
@@ -461,9 +459,10 @@
 
 def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
     """Converts the high level command to NpuOperation"""
-    if cmd.cmdtype == CommandType.DMA:
+    npu_op: NpuOperation
+    if isinstance(cmd, DMA):
         npu_op = create_dma_op(cmd, arch)
-    elif cmd.cmdtype == CommandType.NpuStripe:
+    elif isinstance(cmd, NpuStripe):
         npu_block_type = cmd.ps.primary_op.type.npu_block_type
         if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
             npu_op = create_npu_conv2d_op(cmd, arch)
@@ -475,8 +474,6 @@
             npu_op = create_npu_elementwise_op(cmd, arch)
         else:
             assert 0, f"Unknown command type {npu_block_type}"
-    # add a link to the high level command for debugging purposes
-    npu_op.cmd = cmd
     return npu_op
 
 
@@ -486,7 +483,7 @@
     npu_op_list = []
     npu_op_to_cmd = dict()  # map from npu op to high level command
     for cmd in sg.high_level_command_stream:
-        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
+        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
             print("Warning: Skipping register command stream generation for", cmd.ps)
         else:
             npu_op = convert_command_to_npu_op(cmd, arch)
@@ -498,8 +495,8 @@
 
     def add_to_debug_db(npu_op: NpuOperation, offset: int):
         """Adds info to the debug database"""
-        if not is_dma_op(npu_op):
+        if not isinstance(npu_op, NpuDmaOperation):
             cmd = npu_op_to_cmd[npu_op]
             DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)
 
-    sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db)
+    sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd)
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index 8e28b95..8a23b51 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -20,7 +20,8 @@
 import numpy as np
 
 from . import numeric_util
-from .high_level_command_stream import CommandType
+from .high_level_command_stream import DMA
+from .high_level_command_stream import NpuStripe
 from .tensor import create_const_tensor
 from .tensor import create_equivalence_id
 from .tensor import TensorPurpose
@@ -101,11 +102,11 @@
     lut_start = arch.shram_lut_address
     lut_end = lut_start + arch.shram_lut_size
     for cmd in sg.high_level_command_stream:
-        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
+        if isinstance(cmd, NpuStripe) and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
             # The command overwrites the last 2 banks containing the LUT; next LUT operation will require DMA
             # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
             lut_state = LUTState()
-        if cmd.cmdtype != CommandType.DMA or cmd.out_tensor.purpose != TensorPurpose.LUT:
+        if not isinstance(cmd, DMA) or cmd.out_tensor.purpose != TensorPurpose.LUT:
             # Non-LUT operation; leave untouched
             cmd_stream.append(cmd)
             continue
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 45fae21..32cba36 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -18,10 +18,17 @@
 import copy
 from collections import namedtuple
 from enum import Enum
+from typing import Any
+from typing import Dict
+from typing import List
 from typing import Optional
+from typing import TYPE_CHECKING
 
 from .numeric_util import full_shape
 
+if TYPE_CHECKING:
+    from .tensor import Tensor
+
 PointXY = namedtuple("PointXY", "x y")
 PointXYZ = namedtuple("PointXYZ", "x y z")
 
@@ -392,9 +399,9 @@
     def __init__(self, op_type: Op, name: str):
         self.type = op_type
         self.name = name
-        self.attrs = {}
-        self.inputs = []
-        self.outputs = []
+        self.attrs: Dict[str, Any] = {}
+        self.inputs: List[Tensor] = []
+        self.outputs: List[Tensor] = []
         self.flops = 0
         self.run_on_npu = True
         # Fused activation function. If not none: operator code.
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index d4947b1..fa56d35 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -20,6 +20,7 @@
 from collections import defaultdict
 from enum import Enum
 from enum import IntEnum
+from typing import Dict
 from typing import List
 from typing import Optional
 
@@ -33,6 +34,7 @@
 from .api import NpuBlockOperation
 from .api import NpuBlockTraversal
 from .api import NpuConv2DOperation
+from .api import NpuConvDepthWiseOperation
 from .api import NpuDataType
 from .api import NpuDmaOperation
 from .api import NpuElementWiseOp
@@ -68,13 +70,13 @@
 from .numeric_util import round_away_zero
 from .numeric_util import round_up_to_int
 from .operation import NpuBlockType
+from .range_set import MemoryAccessSet
 from .register_command_stream_util import calc_blockdep
 from .register_command_stream_util import get_dma_memory_accesses
 from .register_command_stream_util import get_op_memory_accesses
 from .register_command_stream_util import get_strides
 from .register_command_stream_util import get_wait_dependency
 from .register_command_stream_util import has_ifm2
-from .register_command_stream_util import is_dma_op
 from .register_command_stream_util import to_kernel
 from .register_command_stream_util import UNARY_ELEMWISE_OPS
 from .register_command_stream_util import Watermark
@@ -549,15 +551,13 @@
 
 def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> SharedBufferAllocation:
     """Creates shared buffer allocation for the given operation"""
-    op_type = npu_op.op_type
-    block_type = NpuBlockType.Default
-    if op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         block_type = NpuBlockType.ConvolutionMxN
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         block_type = NpuBlockType.ConvolutionDepthWise
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         block_type = NpuBlockType.ElementWise
     else:
         assert 0, "Unsupported operation"
@@ -599,7 +599,7 @@
     generate_activation(emit, npu_op.activation, npu_op.ofm)
     shared_buffer = create_shared_buffer(npu_op, arch)
     generate_block_config(emit, npu_op, arch, shared_buffer)
-    if npu_op.op_type == NpuOperationType.ElementWise:
+    if isinstance(npu_op, NpuElementWiseOperation):
         generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
     else:
         generate_shram_registers_non_elementwise(emit, shared_buffer)
@@ -746,17 +746,20 @@
         print(f"         {stride_str}, tiles: w0={t.width_0}, h0={t.height_0}, h1={t.height_1}, base={addresses}")
 
 
-def print_operation(npu_op: NpuOperation, index: int = 0):
-    pass_info = f", {npu_op.cmd}" if hasattr(npu_op, "cmd") else ""
-    if is_dma_op(npu_op):
+def print_operation(npu_op: NpuOperation, index: int = 0, cmd=None):
+    pass_info = f", {cmd}" if cmd else ""
+    if isinstance(npu_op, NpuOperation) and not isinstance(npu_op, (NpuDmaOperation, NpuBlockOperation)):
+        print(f"{index} {npu_op.op_type.name}{pass_info}")
+        return
+    if isinstance(npu_op, NpuDmaOperation):
         print(f"{index} DMA_START src={npu_op.src}, dest={npu_op.dest}{pass_info}")
         return
     k = None if npu_op.kernel is None else to_kernel(npu_op.kernel)
-    if npu_op.op_type in (NpuOperationType.Pooling, NpuOperationType.ElementWise):
+    if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)):
         print(f"{index} {npu_op.sub_op_type.name} {npu_op.op_type.name}:{pass_info}")
     else:
         if (
-            npu_op.op_type == NpuOperationType.Conv2D
+            isinstance(npu_op, NpuConv2DOperation)
             and k.elements_wh() * k.stride.x * k.stride.y * k.dilation.x * k.dilation.y == 1
         ):
             fc = "FullyConnected "
@@ -783,16 +786,19 @@
         if act.op_type != NpuActivationOp.NONE_OR_RELU or act.min is not None or act.max is not None:
             lut = f", lut index={act.lookup_table_index}" if act.op_type == NpuActivationOp.TABLE_LOOKUP else ""
             print(f"      Activation: {act.op_type.name}, min={act.min}, max={act.max}{lut}")
-    if npu_op.op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         print(f"      {npu_op.block_traversal}")
     bh, bw, bc = npu_op.block_config
-    rescale = f", rescale={npu_op.rescale}" if hasattr(npu_op, "rescale") else ""
+    rescale = (
+        f", rescale={npu_op.rescale}" if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)) else ""
+    )
     print(f"      Block config: h={bh},w={bw},c={bc}, {npu_op.ifm_upscale}, {npu_op.rounding_mode}{rescale}")
 
 
-def print_operations(npu_op_list: List[NpuOperation]):
+def print_operations(npu_op_list: List[NpuOperation], npu_op_to_cmd=None):
+    npu_op_to_cmd = dict() if npu_op_to_cmd is None else npu_op_to_cmd
     for index, npu_op in enumerate(npu_op_list):
-        print_operation(npu_op, index)
+        print_operation(npu_op, index, npu_op_to_cmd.get(npu_op))
 
 
 # -------------------------------------------------------------------
@@ -802,16 +808,15 @@
 
 def generate_operation_code(emit: CommandStreamEmitter, npu_op: NpuOperation):
     """Generates NPU_OP_* command"""
-    op_type = npu_op.op_type
-    if op_type == NpuOperationType.Dma:
+    if isinstance(npu_op, NpuDmaOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, npu_op.channel * 16 + npu_op.mode)
-    elif op_type == NpuOperationType.Conv2D:
+    elif isinstance(npu_op, NpuConv2DOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_CONV)
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE)
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_op_map[npu_op.sub_op_type])
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param=elementwise_op_map[npu_op.sub_op_type])
     else:
         assert 0, "Unsupported operation"
@@ -822,7 +827,9 @@
     generate_common(emit, npu_op, npu_op.block_traversal, arch)
 
 
-def generate_conv_depthwise_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures):
+def generate_conv_depthwise_op(
+    emit: CommandStreamEmitter, npu_op: NpuConvDepthWiseOperation, arch: ArchitectureFeatures
+):
     """Generates register commands for depthwise convolution operations"""
     generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch)
 
@@ -880,23 +887,22 @@
     Generates register commands for the given operation, but not the final NPU_OP_... command.
     Returns the selected block config
     """
-    op_type = npu_op.op_type
-    if op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         generate_conv2d_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         generate_conv_depthwise_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         generate_pooling_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         generate_elementwise_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.Dma:
+    elif isinstance(npu_op, NpuDmaOperation):
         generate_dma_op(emit, npu_op)
     else:
         assert 0, "Unsupported operation"
 
 
 def generate_command_stream(
-    npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None,
+    npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None, npu_op_to_cmd=None
 ) -> List[int]:
     """
     Generates register commands for the given list of NPU operations.
@@ -904,14 +910,16 @@
     """
     emit = CommandStreamEmitter()
     if verbose:
-        print_operations(npu_op_list)
+        print_operations(npu_op_list, npu_op_to_cmd)
     # Calculate memory accesses for every operation
-    memory_accesses = {}
+    memory_accesses: Dict[NpuOperation, MemoryAccessSet] = {}
     for npu_op in npu_op_list:
-        if is_dma_op(npu_op):
+        if isinstance(npu_op, NpuDmaOperation):
             memory_accesses[npu_op] = get_dma_memory_accesses(npu_op)
-        else:
+        elif isinstance(npu_op, NpuBlockOperation):
             memory_accesses[npu_op] = get_op_memory_accesses(npu_op, arch)
+        else:
+            assert 0, "Invalid operation type"
     if arch.is_ethos_u65_system:
         emit.cmd0_with_param(cmd0.NPU_SET_PARALLEL_MODE, arch.ncores - 1)
     dep_watermark = Watermark(0, 0)
@@ -920,7 +928,7 @@
     for op_index, npu_op in enumerate(npu_op_list):
         dep_watermark, cmd_waits = get_wait_dependency(arch, npu_op_list, memory_accesses, op_index, dep_watermark)
         generate_registers_for_op(emit, npu_op, arch)
-        if not is_dma_op(npu_op):
+        if not isinstance(npu_op, NpuDmaOperation) and isinstance(npu_op, NpuBlockOperation):
             # Generate BLOCKDEP
             blockdep = calc_blockdep(arch, prev_op, npu_op)
             blockdep = min(blockdep, arch.max_blockdep)
@@ -951,12 +959,12 @@
     """
     Internal implementation of the public facing API for finding block configs.
     """
-    if is_dma_op(npu_op):
-        return []
-    arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
-    shared_buffer = create_shared_buffer(npu_op, arch)
-    blocks = find_suitable_block_configs(arch, shared_buffer)
-    return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+    if isinstance(npu_op, NpuBlockOperation):
+        arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
+        shared_buffer = create_shared_buffer(npu_op, arch)
+        blocks = find_suitable_block_configs(arch, shared_buffer)
+        return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+    return []
 
 
 def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]:
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index ce49fc2..55fa620 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -68,11 +68,6 @@
     return npu_op.ifm2 is not None and npu_op.ifm2_scalar is None
 
 
-def is_dma_op(npu_op: NpuOperation) -> bool:
-    """Checks if op is a DMA operation"""
-    return npu_op.op_type == NpuOperationType.Dma
-
-
 def shape3d_size(shape: NpuShape3D) -> int:
     return shape.width * shape.height * shape.depth
 
@@ -302,9 +297,9 @@
         prev_access = memory_accesses[prev_op]
 
         # Check NPU consuming DMA output
-        if is_dma_op(prev_op):
+        if isinstance(prev_op, NpuDmaOperation):
             if index >= dma_index:
-                if not is_dma_op(npu_op):
+                if not isinstance(npu_op, NpuDmaOperation):
                     if (dma_outstanding == -1) and prev_access.conflicts(op_access):
                         dma_outstanding = dma_ops
                 dma_ops += 1  # Count DMA ops in the pipeline
@@ -313,7 +308,7 @@
         # Check DMA consuming NPU output
         else:
             if index >= npu_index:
-                if is_dma_op(npu_op) and npu_outstanding == -1 and prev_access.conflicts(op_access):
+                if isinstance(npu_op, NpuDmaOperation) and npu_outstanding == -1 and prev_access.conflicts(op_access):
                     npu_outstanding = npu_ops
                 npu_ops += 1  # Count NPU ops in the pipeline
                 if npu_ops >= arch.max_outstanding_kernels: