MLBEDSW-6260: Add support for using DMA to copy feature maps

- Reshape ops can be bypassed, so there is no need for the NPU to process them.
However, there are use cases where the IFM must be preserved, so a memcpy is needed.
Until now this has been implemented by an AvgPool.
- To reduce the cost of the AvgPool, the IFM can instead be copied by DMA.
This is faster, and the copy can be turned into a real NOP in cases where
the IFM and the OFM can use the same memory space.
- Added a new memcpy op. Only the NHWC format is supported since DMA cannot
change the format on the fly.
- Allow ofm to reuse ifm for memcpy op
- Make sure the DMA copy size is 16-byte aligned

Change-Id: I3605a48d47646ff60d2bb3644dd3a23f872235a7
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 24a5583..e8d5ac6 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -27,7 +27,7 @@
 from .errors import UnsupportedFeatureError
 from .errors import VelaError
 from .operation import Op
-from .operation_util import create_avgpool_nop
+from .operation_util import create_memcpy
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
 from .tensor import QuantizationParameters
@@ -89,6 +89,11 @@
     return False
 
 
+def _avoid_nhcwb16_for_memory_only(tens):
+    # check all producers/consumers to see if any op is preventing NHCWB16
+    return any(op.type == Op.Memcpy for op in (tens.consumer_list + tens.ops))
+
+
 # Check if non linear format can be used
 def check_format_restrictions(tens, arch):
     if len(tens.ops) < 1:
@@ -116,6 +121,10 @@
     if _avoid_nhcwb16_for_shapes(tens):
         return
 
+    # Memory only ifm/ofm exception: DMA ops must use NHWC
+    if _avoid_nhcwb16_for_memory_only(tens):
+        return
+
     # Resize bilinear half pixel center implementation requires OFM with linear format to
     # allow stride modification in H/W dimensions.
     for op in tens.ops:
@@ -274,10 +283,10 @@
 
 
 def insert_copy_op_before_op(op):
-    # Create a avg_pool nop op with ifm as input
+    # Create a memcpy op with ifm as input
     tens = op.ifm
     copy_tens = tens.clone()
-    copy_op = create_avgpool_nop(f"{tens.name}_avgpool")
+    copy_op = create_memcpy(f"{tens.name}_memcpy")
     copy_op.add_input_tensor(tens)
     copy_op.set_output_tensor(copy_tens)
     copy_op.set_ifm_ofm_shapes()
@@ -290,9 +299,9 @@
 def insert_copy_op_after_tens(tens):
     tens_cons_list_copy = tens.consumer_list.copy()
 
-    # Create a avg_pool nop op with ifm as input
+    # Create a memcpy op with ifm as input
     copy_tens = tens.clone()
-    copy_op = create_avgpool_nop(tens.name + "_avgpool")
+    copy_op = create_memcpy(tens.name + "_memcpy")
     copy_op.add_input_tensor(tens)
     copy_op.set_output_tensor(copy_tens)
     copy_op.set_ifm_ofm_shapes()
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 609f855..09c1805 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -293,3 +293,19 @@
     def get_operation_count(self):
         # returns numpy array of (DPU blocks, dma_ops)
         return np.array((0, 1))
+
+
+class NOP(Command):
+    def __init__(self, ps, in_tensor, out_tensor):
+        self.ps = ps
+        self.in_tensor = in_tensor
+        self.out_tensor = out_tensor
+
+    def __str__(self):
+        return f"<NOP: in={self.in_tensor.name}, out={self.out_tensor.name}>"
+
+    __repr__ = __str__
+
+    def get_operation_count(self):
+        # returns numpy array of (DPU blocks, dma_ops)
+        return np.array((0, 0))
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 5f6a93a..770241b 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -18,6 +18,7 @@
 # Generate a high-level command stream from a schedule
 from .high_level_command_stream import Box
 from .high_level_command_stream import DMA
+from .high_level_command_stream import NOP
 from .high_level_command_stream import NpuStripe
 from .numeric_util import round_up_divide
 from .operation import create_activation_function
@@ -33,6 +34,19 @@
         yield DMA(ps, src_tensor, tensor, box)
 
 
+def dma_feature_map_if_necessary(ps, src_tensor, dst_tensor):
+    box = Box([0] * len(src_tensor.shape), list(src_tensor.shape))
+    src_addr = src_tensor.address_for_coordinate(box.start_coord)
+    dst_addr = dst_tensor.address_for_coordinate(box.start_coord)
+
+    if src_addr != dst_addr or src_tensor.mem_area != dst_tensor.mem_area:
+        yield DMA(ps, src_tensor, dst_tensor, box)
+    else:
+        # Source and destination are the same so no need for a DMA transaction
+        # Create a NOP for visibility when printing the high_level_command_stream
+        yield NOP(ps, src_tensor, dst_tensor)
+
+
 def generate_high_level_command_stream_for_schedule(nng, sg, arch, verbose_high_level_command_stream):
     res = []
     # sg.sched_ops are ordered by execution
@@ -224,21 +238,24 @@
                     lut_dma_done = True
                     yield from dma_if_necessary(sched_op.parent_ps, lut_box, lut_tensor)
 
-                yield NpuStripe(
-                    sched_op.parent_ps,
-                    block_config.old_style_representation(),
-                    is_first_h_stripe,
-                    is_last_h_stripe,
-                    ifm_tensor,
-                    ifm_box,
-                    ofm_tensor,
-                    ofm_box,
-                    weight_tensor,
-                    weight_box,
-                    scale_tensor,
-                    ifm2_tensor=ifm2_tensor,
-                    ifm2_box=ifm2_box,
-                    pad_top=pad_top,
-                    pad_bottom=pad_bottom,
-                    reversed_operands=sched_op.reversed_operands,
-                )
+                if parent_op.type == Op.Memcpy:
+                    yield from dma_feature_map_if_necessary(sched_op.parent_ps, ifm_tensor, ofm_tensor)
+                else:
+                    yield NpuStripe(
+                        sched_op.parent_ps,
+                        block_config.old_style_representation(),
+                        is_first_h_stripe,
+                        is_last_h_stripe,
+                        ifm_tensor,
+                        ifm_box,
+                        ofm_tensor,
+                        ofm_box,
+                        weight_tensor,
+                        weight_box,
+                        scale_tensor,
+                        ifm2_tensor=ifm2_tensor,
+                        ifm2_box=ifm2_box,
+                        pad_top=pad_top,
+                        pad_bottom=pad_bottom,
+                        reversed_operands=sched_op.reversed_operands,
+                    )
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 2c62c6f..7634fe1 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -54,6 +54,7 @@
 from .high_level_command_stream import Box
 from .high_level_command_stream import Command
 from .high_level_command_stream import DMA
+from .high_level_command_stream import NOP
 from .high_level_command_stream import NpuStripe
 from .numeric_util import quantise_float32
 from .numeric_util import round_up
@@ -627,7 +628,8 @@
     else:
         src_addr = cmd.in_tensor.address_for_coordinate(cmd.box.start_coord)
         dest_addr = cmd.out_tensor.address_for_coordinate(cmd.box.start_coord)
-        sz = cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr
+        # DMA must use 16 bytes alignment (tensors are always aligned but the sz calculation uses actual size)
+        sz = round_up(cmd.in_tensor.address_for_coordinate(cmd.box.end_coord, is_top_box=True) - src_addr, 16)
     src = NpuAddressRange(src_region, int(src_addr), int(sz))
     dest = NpuAddressRange(dest_region, int(dest_addr), int(sz))
     return NpuDmaOperation(src, dest)
@@ -663,6 +665,9 @@
     for cmd in sg.high_level_command_stream:
         if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
             print("Warning: Skipping register command stream generation for", cmd.ps)
+        elif isinstance(cmd, NOP):
+            # NOP should not generate anything
+            continue
         else:
             npu_op = convert_command_to_npu_op(cmd, arch)
             npu_op_list.append(npu_op)
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 05e481e..995a0cc 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -165,16 +165,11 @@
 
 
 def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
-    def _tensor_should_be_ignored(tens):
-        if tens.ifm_write_protected:
-            return True
-        return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)
-
-    # Check if possible to merge ifm/ofm live ranges of elementwise op
     ifm_tens = None
     if sched_op.op_type.is_elementwise_op():
+        # Check if possible to merge ifm/ofm live ranges of elementwise op
         elem_op = sched_op.parent_op
-        if not _tensor_should_be_ignored(elem_op.ofm):
+        if not tensor_should_be_ignored(elem_op.ofm, target_mem_area, target_mem_type_set):
             # Check if overwriting the inputs can be allowed
             OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
             outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)
@@ -183,7 +178,6 @@
                 inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
             if elem_op.ifm2 is not None:
                 inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))
-
             # find an input tensor that can be overwritten by the output
             for inp in inps:
                 if (
@@ -192,7 +186,8 @@
                     # check input tensor is valid
                     and inp.tens is not None
                     and inp.tens.shape != []
-                    and not _tensor_should_be_ignored(inp.tens)
+                    and not inp.tens.ifm_write_protected
+                    and not tensor_should_be_ignored(inp.tens, target_mem_area, target_mem_type_set)
                     # check input and output tensors are compatible
                     and inp.tens.format == outp.tens.format
                     and inp.tens.dtype == outp.tens.dtype
@@ -203,6 +198,17 @@
                 ):
                     ifm_tens = inp.tens
                     break
+    elif sched_op.op_type == Op.Memcpy:
+        # Check if possible to merge ifm/ofm live ranges of dma op
+        dma_op = sched_op.parent_op
+        ifm = dma_op.ifm
+        ofm = dma_op.ofm
+        if not (
+            tensor_should_be_ignored(ifm, target_mem_area, target_mem_type_set)
+            or tensor_should_be_ignored(ofm, target_mem_area, target_mem_type_set)
+        ):
+            # Currently DMA is only used when bypassing memory only ops, so it is ok to reuse the ifm
+            ifm_tens = ifm
 
     return ifm_tens
 
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 967a7ac..8001124 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -472,6 +472,10 @@
             _estimate_output_cycles_per_element(arch, op_type, faf_type, query)
             * Shape4D.round_up(query.ofm_shape, ofm_rounding).elements()
         )
+    # DMA cycle calculation
+    elif query.npu_block_type == NpuBlockType.Dma:
+        # Return 0 since this is not an actual NPU op
+        cycles.op_cycles = 0
     else:
         assert False
 
@@ -541,6 +545,10 @@
                 elif query.ifm2_bits > 8:
                     # ifm2 is a non 8-bit scalar
                     access.ifm_read[1] = Shape4D.round_up(query.ifm2_shape, ifm_rounding).elements()
+    # DMA
+    elif query.npu_block_type == NpuBlockType.Dma:
+        # Return empty access since this is not an actual NPU op
+        return access
     # Unknown
     else:
         assert False
@@ -646,18 +654,28 @@
 
     # LUT Transfer
     parent_op = op.parent_op
-    lut_transfer_cycles = 0
+    dma_transfer_cycles = 0
     if parent_op.activation_lut:
         lut_tensor = [tens for tens in parent_op.inputs if tens.purpose == TensorPurpose.LUT][0]
         src_tensor = lut_tensor.src_tensor
         if src_tensor and lut_tensor.mem_area != src_tensor.mem_area:
             bw = src_tensor.storage_size()
-            lut_transfer_cycles = measure_mem2mem_cycles(arch, src_tensor.mem_area, lut_tensor.mem_area, bw)
+            dma_transfer_cycles += measure_mem2mem_cycles(arch, src_tensor.mem_area, lut_tensor.mem_area, bw)
 
             bws[src_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw
             # LUT read from SHRAM TODO remove?
             scaled_bws[lut_tensor.mem_area][lut_tensor.purpose][BandwidthDirection.Read] += bw
 
+    # DMA Transfer
+    if parent_op.type == Op.Memcpy:
+        src_tensor = parent_op.ifm
+        dst_tensor = parent_op.ofm
+        if src_tensor.mem_area != dst_tensor.mem_area:
+            bw = src_tensor.storage_size()
+            dma_transfer_cycles += measure_mem2mem_cycles(arch, src_tensor.mem_area, dst_tensor.mem_area, bw)
+            bws[src_tensor.mem_area][src_tensor.purpose][BandwidthDirection.Read] += bw
+            bws[dst_tensor.mem_area][src_tensor.purpose][BandwidthDirection.Write] += bw
+
     if cost.npu_weights_tensor and cost.buffered_weight_tensors:
         # DMA Weight Transfer
         sz = 0
@@ -690,11 +708,11 @@
                 cycles.op_cycles + cost.full_weight_transfer_cycles - min(ws_first_transfer_cycles, slack_cycles)
             )
 
-        # Add cycles for LUT Transfer
-        cycles_a[PassCycles.Npu] += lut_transfer_cycles
+        # Add cycles for LUT + memcpy op Transfer
+        cycles_a[PassCycles.Npu] += dma_transfer_cycles
     else:
-        # Add cycles for LUT Transfer
-        cycles_a[PassCycles.Npu] += max(lut_transfer_cycles - slack_cycles, 0)
+        # Add cycles for LUT + memcpy op Transfer
+        cycles_a[PassCycles.Npu] += max(dma_transfer_cycles - slack_cycles, 0)
 
     # OFM write
     ofm = op.parent_op.ofm
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 19b00b3..6be9dc2 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -51,6 +51,7 @@
     ConvolutionDepthWise = 4
     ElementWise = 5
     ReduceSum = 6
+    Dma = 7
 
 
 class Kernel:
@@ -174,6 +175,7 @@
     )
     Dequantize = OperatorInfo(indices=NNG_IFM_INDICES)
     Div = OperatorInfo()
+    Memcpy = OperatorInfo(block_type=NpuBlockType.Dma, indices=NNG_IFM_INDICES)
     Elu = OperatorInfo()
     EmbeddingLookup = OperatorInfo()
     EmbeddingLookupSparse = OperatorInfo()
@@ -373,6 +375,9 @@
     def is_resize_op(self):
         return self in (Op.ResizeBilinear, Op.ResizeNearestNeighbor)
 
+    def is_memcpy_op(self):
+        return self.info.block_type == NpuBlockType.Dma
+
     def needs_bias(self):
         return bool(self.info.indices.biases)
 
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 7b66dff..21f9dbe 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -51,6 +51,12 @@
     return op
 
 
+def create_memcpy(name: str) -> Operation:
+    op = Operation(Op.Memcpy, name)
+    op.run_on_npu = True
+    return op
+
+
 def create_pad_nop(name: str) -> Operation:
     op = Operation(Op.Pad, name)
     op.run_on_npu = True
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 5a9f957..e43a919 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -39,6 +39,7 @@
     StartupInit = 64
     MemoryOnly = 128
     PostFusingLimited = 256
+    Memcpy = 512
 
 
 mac_main_ops = set(
@@ -95,6 +96,7 @@
         Op.ExpandDims,
     )
 )
+memcpy_ops = set((Op.Memcpy,))
 
 
 test_sequence = [
@@ -160,6 +162,16 @@
     ),
     (
         # ops_set
+        memcpy_ops,
+        # incompatible_pack_flags
+        PassFlags.Cpu | PassFlags.MemoryOnly | PassFlags.Mac | PassFlags.Main | PassFlags.PostFusingLimited,
+        # flags_to_set
+        PassFlags.Npu | PassFlags.Memcpy | PassFlags.Main,
+        # flags_to_clear
+        PassFlags.Empty,
+    ),
+    (
+        # ops_set
         cpu_ops,
         # incompatible_pack_flags
         PassFlags.Npu | PassFlags.MemoryOnly | PassFlags.Main,
@@ -248,7 +260,11 @@
 
                         if flags_to_set & PassFlags.Npu:
                             if flags_to_set & (
-                                PassFlags.Mac | PassFlags.ElementWise | PassFlags.Post | PassFlags.PostFusingLimited
+                                PassFlags.Mac
+                                | PassFlags.ElementWise
+                                | PassFlags.Post
+                                | PassFlags.PostFusingLimited
+                                | PassFlags.Memcpy
                             ):
                                 assert len(curr_op.inputs) >= 1
                                 ifm_tensor = curr_op.ifm