MLBEDSW-3146: Cycle estimation for conv/pooling ops
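
Replace the flat MACs-per-cycle DPU estimate for convolution and
pooling passes with a per-block model: cost one OFM block from the
block config, sub-kernel decomposition, IFM block depth and block
traversal, scale by the number of OFM blocks, and overlap the MAC
engine with the output stage. ReduceSum and vector products are
estimated through the same path, and get_output_cycle_estimate() now
takes explicit operator and tensor arguments so it can be reused per
block.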

Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
Change-Id: Ic6ae795a1626d1cdf63a69d2ff86f7cd898f3134
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 24b4c68..4d221be 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -24,13 +24,14 @@
 import numpy as np
 
 from . import numeric_util
+from .architecture_features import Accelerator
 from .architecture_features import Block
-from .architecture_features import SHRAMElements
 from .data_type import DataType
 from .nn_graph import PassPlacement
 from .nn_graph import SchedulerRewrite
 from .operation import NpuBlockType
 from .operation import Op
+from .shared_buffer_allocation import is_acc_40bits_used
 from .tensor import MemArea
 from .tensor import shape_num_elements
 from .tensor import TensorBlockTraversal
@@ -212,22 +213,20 @@
     return total_blocks, total_area, block_setup
 
 
-def get_output_cycle_estimate(arch, ps):
-    primary_op = ps.primary_op
-    assert primary_op
-    npu_block_type = primary_op.type.npu_block_type
+def get_output_cycle_estimate(
+    arch, npu_block_type, primary_op, num_elems, ifm_tensor, ofm_tensor, ifm2_tensor, use_acc_40bits=False
+):
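+    # Pick an output performance index from the op and tensor types and an
+    # activation index from the fused activation, then cost num_elems at the
+    # slower of the two per-element rates.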
     faf = primary_op.activation
-
-    if npu_block_type == NpuBlockType.ElementWise and ps.ifm_tensor.dtype == DataType.int32:
-        if ps.ifm2_tensor is None:
+    if npu_block_type == NpuBlockType.ElementWise and ifm_tensor.dtype == DataType.int32:
+        if ifm2_tensor is None:
             # Unary op
             output_perf_index = 0
         else:
             # Binary op
             output_perf_index = 1
-    elif ps.primary_op.type == Op.Mul and ps.ofm_tensor.dtype == DataType.int32:
+    elif primary_op.type == Op.Mul and ofm_tensor.dtype == DataType.int32:
         output_perf_index = 2
-    elif ps.primary_op.type == Op.Mul or (
+    elif primary_op.type == Op.Mul or (
         npu_block_type
         in (
             NpuBlockType.ConvolutionMxN,
@@ -236,13 +235,13 @@
             NpuBlockType.ReduceSum,
             NpuBlockType.VectorProduct,
         )
-        and ps.shared_buffer.use_accumulator_element == SHRAMElements.Acc40
+        and use_acc_40bits
     ):
         output_perf_index = 3
-    elif ps.primary_op.type in (Op.Add, Op.Sub):
-        input_scale = ps.ifm_tensor.quantization.scale_f32
-        input2_scale = ps.ifm2_tensor.quantization.scale_f32
-        output_scale = ps.ofm_tensor.quantization.scale_f32
+    elif primary_op.type in (Op.Add, Op.Sub):
+        input_scale = ifm_tensor.quantization.scale_f32
+        input2_scale = ifm2_tensor.quantization.scale_f32
+        output_scale = ofm_tensor.quantization.scale_f32
 
         if "resizebilinear" in primary_op.attrs:
             output_scale = input2_scale
@@ -253,7 +252,7 @@
         else:
             # Advanced Add/Sub
             output_perf_index = 5
-    elif ps.primary_op.type.is_maxpool_op():
+    elif primary_op.type.is_maxpool_op():
         output_perf_index = 6
     else:
         output_perf_index = 7
@@ -265,13 +264,95 @@
     else:
         activation_perf_index = 2
 
-    num_elems = ps.outputs[0].elements()
     cycle_per_elem = max(
         arch.output_cycles_per_elem[output_perf_index], arch.activation_cycles_per_elem[activation_perf_index]
     )
     return num_elems * cycle_per_elem
 
 
+def get_conv_pooling_cycle_estimate(
+    arch, npu_block_type, primary_op, block_config: Block, block_traversal, kernel_dims, ifm_tensor, ofm_tensor
+):
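+    # Estimate DPU cycles for conv/depthwise/pooling/vector product/reduce sum
+    # by costing a single OFM block and scaling by the number of OFM blocks,
+    # overlapping the MAC engine with the per-block output stage.
+
+    # Number of OFM micro-blocks (ublocks) in one block config.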
+    num_ublk = (
+        (block_config.width // arch.config.ofm_ublock.width)
+        * (block_config.height // arch.config.ofm_ublock.height)
+        * (block_config.depth // arch.config.ofm_ublock.depth)
+    )
+    num_elems_blk = block_config.width * block_config.height * block_config.depth
+    ifm_tens_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
+    ofm_tens_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1)
+    use_acc_40bits = is_acc_40bits_used(npu_block_type, ifm_tensor, ofm_tensor)
+
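+    # Split the kernel into sub-kernels bounded by the architecture's
+    # sub-kernel limits; DPU cycles are accumulated per sub-kernel.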
+    sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
+    n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
+    n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
+    sub_kernel_x = [
+        min((kernel_dims[1] - i * sub_kernel_limits[1]), sub_kernel_limits[1]) for i in range(n_sub_kernels_x)
+    ]
+    sub_kernel_y = [
+        min((kernel_dims[0] - i * sub_kernel_limits[0]), sub_kernel_limits[0]) for i in range(n_sub_kernels_y)
+    ]
+    sub_kernel_size = (x * y for y in sub_kernel_y for x in sub_kernel_x)
+
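+    # IFM block depth follows the IFM datatype and the block traversal;
+    # it is unused for pooling and left at 0 there.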
+    ifm_blk_depth = 0
+    if npu_block_type != NpuBlockType.Pooling:
+        if ifm_tensor.dtype.size_in_bits() == 16 or block_traversal == TensorBlockTraversal.PartKernelFirst:
+            ifm_blk_depth = 16
+        elif ifm_tensor.dtype.size_in_bits() == 8:
+            ifm_blk_depth = 32
+        else:
+            ifm_blk_depth = 8
+
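+    # MAC-engine cycles for one OFM block, accumulated over all sub-kernels.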
+    cycles_dpu_blk = 0
+
+    for num_kernel_elems in sub_kernel_size:
+        if npu_block_type == NpuBlockType.Pooling:
+            cycles = max(4, num_kernel_elems) * num_ublk
+            if ifm_tensor.dtype.size_in_bits() == 16 and arch.accelerator_config != Accelerator.Ethos_U55_32:
+                cycles *= 2
+        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
+            cycles = 4 * numeric_util.round_up_divide(num_kernel_elems, 4) * num_ublk
+            if ifm_tensor.dtype.size_in_bits() == 16:
+                cycles *= 2
+        elif (
+            (npu_block_type == NpuBlockType.ConvolutionMxN and block_traversal != TensorBlockTraversal.PartKernelFirst)
+            or npu_block_type == NpuBlockType.VectorProduct
+            or npu_block_type == NpuBlockType.ReduceSum
+        ):
+            cycles = 4 * num_kernel_elems * num_ublk * numeric_util.round_up_divide(ifm_tens_shape[3], ifm_blk_depth)
+        else:
+            assert block_traversal == TensorBlockTraversal.PartKernelFirst
+            divider = 2 if ifm_tensor.dtype.size_in_bits() == 16 else 4
+            cycles = 4 * (
+                numeric_util.round_up_divide(num_kernel_elems, divider)
+                * numeric_util.round_up_divide(ifm_blk_depth, 8)
+                * num_ublk
+                * numeric_util.round_up_divide(ifm_tens_shape[3], ifm_blk_depth)
+            )
+        cycles_dpu_blk += cycles
+
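+    # The block workload is shared between the available NPU cores.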
+    cycles_dpu_blk /= arch.ncores
+
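+    # Number of OFM blocks needed to cover the full OFM.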
+    num_ofm_blk = (
+        numeric_util.round_up_divide(ofm_tens_shape[1], block_config.height)
+        * numeric_util.round_up_divide(ofm_tens_shape[2], block_config.width)
+        * numeric_util.round_up_divide(ofm_tens_shape[3], block_config.depth)
+    )
+
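+    # Output-stage cycles for a single OFM block.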
+    cycles_output_blk = get_output_cycle_estimate(
+        arch, npu_block_type, primary_op, num_elems_blk, ifm_tensor, ofm_tensor, None, use_acc_40bits
+    )
+
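+    # The MAC engine and the output stage overlap: the slower stage is paid
+    # for every block, the faster one only once.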
+    if cycles_dpu_blk > cycles_output_blk:
+        total_cycles = cycles_dpu_blk * num_ofm_blk + cycles_output_blk
+    else:
+        total_cycles = cycles_output_blk * num_ofm_blk + cycles_dpu_blk
+
+    return total_cycles
+
+
 def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
     if block_config is None:
         block_config = ps.block_config
@@ -302,7 +383,12 @@
         ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
 
         if npu_block_type in set(
-            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling)
+            (
+                NpuBlockType.ConvolutionMxN,
+                NpuBlockType.ConvolutionDepthWise,
+                NpuBlockType.Pooling,
+                NpuBlockType.ReduceSum,
+            )
         ):
-            # extent the ifm to full dimension
+            # extend the ifm to full dimensions
             ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1))
@@ -316,12 +402,22 @@
             ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
             ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3]  # width  += left and right
 
+            block_traversal = TensorBlockTraversal.Default
+
             strides = primary_op.attrs["strides"]
             if npu_block_type != NpuBlockType.Pooling:
-                weight_tensor_shape = weight_tensor.shape
-                weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
-                weight_tensor_element_size = weight_tensor.element_size()
-                weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
+                if npu_block_type == NpuBlockType.ReduceSum:
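+                    # ReduceSum has no weight tensor: model it as a
+                    # depth-first 1x1 kernel with zero weight bandwidth.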
+                    block_traversal = TensorBlockTraversal.DepthFirst
+                    weight_tensor_shape = [1, 1, ifm_tensor.shape[3], ofm_tensor.shape[3]]
+                    weight_tensor_bandwidth_shape = [0] * 4
+                    weight_tensor_element_size = 0
+                    weight_tensor_bandwidth_compression_scale = 0.0
+                else:
+                    block_traversal = weight_tensor.block_traversal
+                    weight_tensor_shape = weight_tensor.shape
+                    weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
+                    weight_tensor_element_size = weight_tensor.element_size()
+                    weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
                 nn_ops = (
                     int(ofm_tensor.shape[0])
                     * int(ofm_tensor.shape[1])
@@ -394,7 +490,7 @@
             n_kernel_xy = kernel_dims[0] * kernel_dims[1]
             n_input_channels_at_a_time = block_config[2]
 
-            if npu_block_type == NpuBlockType.Pooling or weight_tensor.block_traversal in set(
+            if npu_block_type == NpuBlockType.Pooling or block_traversal in set(
                 (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
             ):
                 n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
@@ -416,14 +512,18 @@
                     * n_kernel_xy
                 )
 
-            if npu_block_type == NpuBlockType.Pooling:
-                # TODO: improve pooling estimation
-                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle / 2
-            else:
-                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
             macs[MacCount.NeuralNetworkMacs] += nn_ops
             macs[MacCount.HardwareMacs] += num_mac_ops
-
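+            # block_config is [height, width, input channels, OFM depth];
+            # Block() expects (width, height, depth).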
+            cycles[PassCycles.Dpu] = get_conv_pooling_cycle_estimate(
+                arch,
+                npu_block_type,
+                primary_op,
+                Block(block_config[1], block_config[0], block_config[3]),
+                block_traversal,
+                kernel_dims,
+                ifm_tensor,
+                ofm_tensor,
+            )
         elif npu_block_type == NpuBlockType.VectorProduct:
             nn_macs = (
                 ifm_tensor.shape[0]
@@ -432,7 +532,16 @@
             )
             num_mac_ops = nn_macs
 
-            cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
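+            # A vector product is costed as a 1x1 convolution over the IFM
+            # depth.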
+            cycles[PassCycles.Dpu] = get_conv_pooling_cycle_estimate(
+                arch,
+                npu_block_type,
+                primary_op,
+                Block(block_config[1], block_config[0], block_config[3]),
+                weight_tensor.block_traversal,
+                [1, 1],
+                ifm_tensor,
+                ofm_tensor,
+            )
             macs[MacCount.NeuralNetworkMacs] += nn_macs
             macs[MacCount.HardwareMacs] += num_mac_ops
 
@@ -449,8 +558,9 @@
             weight_read_multiple = non_zero_fraction
         elif npu_block_type == NpuBlockType.ElementWise:
             # Work out how many elements we have and calculate performance.
-            cycles[PassCycles.ElementWise] = get_output_cycle_estimate(arch, ps)
-
+            cycles[PassCycles.ElementWise] = get_output_cycle_estimate(
+                arch, npu_block_type, primary_op, ofm_tensor.elements(), ps.ifm_tensor, ps.ofm_tensor, ps.ifm2_tensor
+            )
     # apply the desired rewrites
     for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list:
         if ps != ps_to_rewrite: