Vela: bandwidth calculation improvements

  - Combine the conv and vector_product calculations
  - Remove internal bandwidth
  - Remove blocks and hw_macs from report
  - Use scaled_bws for cycle estimation (see the sketch below)

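  The cycle estimate divides the efficiency-scaled bandwidth, rather than
  the raw tensor bandwidth, by the memory's per-cycle bandwidth. A minimal
  sketch of that relation, with illustrative numbers (the 1.5x scale
  factor and 16 bytes/cycle are assumptions, not values from this patch):

      raw_bytes = 4096.0               # bytes the pass actually moves
      scaled_bytes = raw_bytes * 1.5   # inflated for inefficient access patterns
      bandwidth_per_cycle = 16.0       # bytes/cycle of the memory area
      cycles = scaled_bytes / bandwidth_per_cycle  # 384.0 cycles, vs 256.0 unscaled
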
Related to: MLBEDSW-3598

Change-Id: I1927a8311ec563f68115e0f2ed077806b86fd717
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 2d7a1b0..8ada1e2 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -90,22 +90,6 @@
         )
 
 
-class MacCount(IntEnum):
-    NeuralNetworkMacs = 0
-    HardwareMacs = auto()
-    Size = auto()
-
-    def display_name(self):
-        return ("Neural Network Macs", "Hardware Macs", "Size")[self.value]
-
-    def identifier_name(self):
-        return ("nn_macs", "hardware_macs", "size")[self.value]
-
-    @staticmethod
-    def all():
-        return (MacCount.NeuralNetworkMacs, MacCount.HardwareMacs)
-
-
 class BandwidthDirection(IntEnum):
     Read = 0
     Write = auto()
@@ -126,77 +110,18 @@
     return np.zeros((MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size))
 
 
-def make_macs_array():
-    return np.zeros(MacCount.Size, np.int)
-
-
 def make_cycles_array():
     return np.zeros(PassCycles.Size)
 
 
 def make_metrics_arrays():
-    return (make_bandwidth_array(), make_macs_array(), make_cycles_array())
-
-
-def get_n_blocks_and_area(
-    ifm_brick_size, ifm_height_width, orig_skirt, clamped_skirt, block_config, min_block_size, strides
-):
-
-    ifm_block_config = (block_config[0] * strides[1], block_config[1] * strides[2])
-
-    n_normal_blocks = []
-    remainder_size = []
-    for i in range(2):
-        non_skirt_dim = ifm_height_width[i] - orig_skirt[i] - orig_skirt[2 + i]
-        n_blocks = non_skirt_dim // ifm_block_config[i]
-        n_normal_blocks.append(n_blocks)
-        remainder_dim = numeric_util.round_up(
-            ((non_skirt_dim - n_blocks * ifm_block_config[i] - 1) // strides[i + 1]) + 1, min_block_size[i]
-        )
-        remainder_size.append(remainder_dim)
-
-    # this will actually calculate reads into the edge padding.
-
-    # there are four cases in total, handling the edges that will not fill a complete block.
-
-    # 0000000001
-    # 0000000001
-    # 0000000001
-    # 0000000001
-    # 0000000001
-    # 0000000001
-    # 2222222223
-    total_blocks = 0
-    total_area = 0
-
-    block_setup = (
-        (n_normal_blocks[0] * n_normal_blocks[1], block_config),
-        (1 * n_normal_blocks[1], (remainder_size[0], block_config[1])),
-        (n_normal_blocks[0] * 1, (block_config[0], remainder_size[1])),
-        (1 * 1, remainder_size),
-    )
-
-    for n_blocks, block_size in block_setup:
-        if block_size[0] == 0 or block_size[1] == 0:
-            continue
-        read_dims = [0, 0]
-        for i in range(2):
-            read_dims[i] = (
-                numeric_util.round_up(clamped_skirt[i], ifm_brick_size[i + 1])
-                + block_size[i] * strides[i + 1]
-                + numeric_util.round_up(clamped_skirt[2 + i], ifm_brick_size[i + 1])
-            )
-        assert n_blocks >= 0
-        total_blocks += n_blocks
-        total_area += n_blocks * read_dims[0] * read_dims[1]
-    assert total_blocks >= 1
-    return total_blocks, total_area, block_setup
+    return (make_bandwidth_array(), 0, make_cycles_array())
 
 
 def get_ifm_block_depth(npu_block_type, ifm_depth, ifm_elemwidth, block_traversal, ofm_blk_depth):
     ifm_blk_depth = ofm_blk_depth
 
-    if npu_block_type == NpuBlockType.ConvolutionMxN or npu_block_type == NpuBlockType.ReduceSum:
+    if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
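+        # 16-bit IFMs and part-kernel-first traversal use an IFM block depth of 16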
         if ifm_elemwidth == 16 or block_traversal == TensorBlockTraversal.PartKernelFirst:
             ifm_blk_depth = 16
         elif ifm_elemwidth == 8:
@@ -213,11 +138,11 @@
     ifm_tens_blk = Tensor((1, ifm_blk.height, ifm_blk.width, ifm_blk.depth), ifm_tensor.dtype, "ifm_blk")
     ofm_tens_blk = Tensor((1, ofm_blk.height, ofm_blk.width, ofm_blk.depth), ofm_tensor.dtype, "ofm_blk")
     cycles_ifm_blk = (
-        estimate_memory_bandwidth(arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk)
+        estimate_memory_transfer_efficiency(arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk)
         / arch.memory_bandwidths_per_cycle[ifm_tensor.mem_area]
     )
     cycles_ofm_blk = (
-        estimate_memory_bandwidth(arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk)
+        estimate_memory_transfer_efficiency(arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk)
         / arch.memory_bandwidths_per_cycle[ofm_tensor.mem_area]
     )
     return (
@@ -449,7 +374,7 @@
     return total_cycles
 
 
-def estimate_memory_bandwidth(arch, mem_area, direction, tensor, block_size: Block, replace_bw=None):
+def estimate_memory_transfer_efficiency(arch, mem_area, direction, tensor, block_size: Block, replace_bw=None):
     if tensor.format not in (TensorFormat.NHWC, TensorFormat.NHCWB16):
         return tensor.bandwidth() if replace_bw is None else replace_bw
 
@@ -493,18 +418,15 @@
     if block_config is None:
         block_config = ps.block_config
     bws = make_bandwidth_array()
-    macs = make_macs_array()
+    scaled_bws = make_bandwidth_array()  # bandwidths scaled by memory transfer efficiency
+    macs = 0
     cycles = make_cycles_array()
-    blocks = 0
     ifm_read_multiple = 1
     weight_read_multiple = 0
 
     if ps.placement in (PassPlacement.MemoryOnly, PassPlacement.StartupInit):
-        return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass
+        return bws, macs, cycles, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass
 
-    min_block_size = arch.min_block_sizes[ps.npu_block_type]
-
-    skirt = (0, 0, 0, 0)
     explicit_padding = (0, 0, 0, 0)
     primary_op = ps.primary_op
     replacement_read_bws = {}
@@ -512,13 +434,13 @@
     ifm_block = Block(block_config[1], block_config[0], block_config[3])
 
     if ps.placement == PassPlacement.Npu and primary_op:
-        skirt = primary_op.attrs.get("skirt", skirt)
         explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
         assert primary_op.type.npu_block_type == ps.npu_block_type
         npu_block_type = primary_op.type.npu_block_type
 
         ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
         ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
+        ofm_tensor_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1)
 
         if npu_block_type == NpuBlockType.ReduceSum:
             block_traversal = TensorBlockTraversal.DepthFirst
@@ -540,21 +462,17 @@
         if npu_block_type in (
             NpuBlockType.ConvolutionMxN,
             NpuBlockType.ConvolutionDepthWise,
+            NpuBlockType.VectorProduct,
             NpuBlockType.Pooling,
             NpuBlockType.ReduceSum,
         ):
             # extend the ifm to full dimensions
-            ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1))
-            ifm_tensor_bandwidth_shape = numeric_util.full_shape(4, ifm_tensor.bandwidth_shape, 1)
-
             batch_size = ifm_tensor_shape[0]
-            ifm_depth = ifm_tensor_bandwidth_shape[3]
 
             # add in padding
             ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
             ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3]  # width  += left and right
 
-            strides = primary_op.attrs["strides"]
             if npu_block_type != NpuBlockType.Pooling:
                 if npu_block_type == NpuBlockType.ReduceSum:
                     weight_tensor_shape = [1, 1, ifm_tensor.shape[3], ofm_tensor.shape[3]]
@@ -562,14 +480,16 @@
                     weight_tensor_element_size = 0
                     weight_tensor_bandwidth_compression_scale = 0.0
                 else:
-                    weight_tensor_shape = weight_tensor.shape
-                    weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
+                    # For vector product, the IO weight format is extended to HWIO, with H = W = 1
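+                    # e.g. a 2D (I, O) weight shape becomes (1, 1, I, O)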
+                    weight_tensor_shape = numeric_util.full_shape(4, weight_tensor.shape, 1)
+                    weight_tensor_bandwidth_shape = numeric_util.full_shape(4, weight_tensor.bandwidth_shape, 1)
                     weight_tensor_element_size = weight_tensor.element_size()
                     weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
+
                 nn_ops = (
-                    int(ofm_tensor.shape[0])
-                    * int(ofm_tensor.shape[1])
-                    * int(ofm_tensor.shape[2])
+                    int(ofm_tensor_shape[0])
+                    * int(ofm_tensor_shape[1])
+                    * int(ofm_tensor_shape[2])
                     * int(weight_tensor_shape[0])
                     * int(weight_tensor_shape[1])
                     * int(weight_tensor_shape[2])
@@ -595,72 +515,25 @@
             n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
             n_sub_kernels = n_sub_kernels_y * n_sub_kernels_x
 
-            clamped_skirt = list(skirt)
-            clamped_skirt[2] = min(clamped_skirt[2], sub_kernel_limits[0] - 1 - clamped_skirt[0])
-            clamped_skirt[3] = min(clamped_skirt[3], sub_kernel_limits[1] - 1 - clamped_skirt[1])
-            n_blocks, area, block_setup = get_n_blocks_and_area(
-                ifm_tensor_brick_size,
-                ifm_tensor_shape[1:3],
-                skirt,
-                clamped_skirt,
-                block_config,
-                min_block_size,
-                strides,
-            )
+            n_full_depth_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], ofm_block.depth)
+            if npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling):
+                n_full_depth_stages = 1  # force no re-read
 
-            blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], ofm_block.depth)
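+            # One full IFM read per sub-kernel, repeated for every full-depth stage.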
+            ifm_read_multiple = n_sub_kernels * n_full_depth_stages
+            replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth() * ifm_read_multiple
 
-            n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], ofm_block.depth)
-            if npu_block_type == NpuBlockType.ConvolutionDepthWise or npu_block_type == NpuBlockType.Pooling:
-                n_weight_stages = 1  # force to no reread
-
-            ifm_tensor_bw = (
-                n_sub_kernels
-                * batch_size
-                * area
-                * ifm_depth
-                * n_weight_stages
-                * ifm_tensor.element_size()
-                * ifm_tensor.bandwidth_compression_scale
-            )
-            replacement_read_bws[ifm_tensor] = ifm_tensor_bw
-            ifm_read_multiple = n_weight_stages
-
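+            # Weights are re-read once for every OFM block in the height/width plane.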
+            weight_read_multiple = numeric_util.round_up_divide(
+                ofm_tensor_shape[1], ofm_block.height
+            ) * numeric_util.round_up_divide(ofm_tensor_shape[2], ofm_block.width)
             replacement_read_bws[weight_tensor] = (
                 batch_size
                 * shape_num_elements(weight_tensor_bandwidth_shape)
                 * weight_tensor_element_size
                 * weight_tensor_bandwidth_compression_scale
-                * n_blocks
-            )  # read once per block and batch
-            weight_read_multiple = n_blocks
+                * weight_read_multiple
+            )
 
-            n_kernel_xy = kernel_dims[0] * kernel_dims[1]
-            n_input_channels_at_a_time = block_config[2]
-
-            if (npu_block_type == NpuBlockType.Pooling) or (
-                block_traversal in (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
-            ):
-                n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
-                n_kernel_xy = max(
-                    n_kernel_xy, 4
-                )  # need at least 4, as this is the minimum duty cycle for secondary accumulator writes
-                if weight_tensor is not None:
-                    n_kernel_xy = numeric_util.round_up(n_kernel_xy, 4)  # weights need to be read in blocks of 4
-
-            num_mac_ops = 0
-            for n_blocks_for_size, block_size in block_setup:
-                num_mac_ops += (
-                    batch_size
-                    * n_blocks_for_size
-                    * block_size[0]
-                    * block_size[1]
-                    * numeric_util.round_up(weight_tensor_shape[2], n_input_channels_at_a_time)
-                    * numeric_util.round_up(weight_tensor_shape[3], ofm_block.depth)
-                    * n_kernel_xy
-                )
-            macs[MacCount.NeuralNetworkMacs] += nn_ops
-            macs[MacCount.HardwareMacs] += num_mac_ops
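+            # Track only the network-level MAC count for this pass.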
+            macs += nn_ops
             cycles[PassCycles.Npu] = estimate_conv_pooling_cycles(
                 arch,
                 npu_block_type,
@@ -673,31 +546,6 @@
                 ofm_tensor,
                 ps.scale_tensor,
             )
-        elif npu_block_type == NpuBlockType.VectorProduct:
-            nn_macs = (
-                ifm_tensor.shape[0]
-                * numeric_util.round_up(weight_tensor.shape[-2], block_config[2])
-                * numeric_util.round_up(weight_tensor.shape[-1], block_config[3])
-            )
-            num_mac_ops = nn_macs
-
-            cycles[PassCycles.Npu] = estimate_conv_pooling_cycles(
-                arch, npu_block_type, primary_op, ifm_block, ofm_block, block_traversal, [1, 1], ifm_tensor, ofm_tensor,
-            )
-            macs[MacCount.NeuralNetworkMacs] += nn_macs
-            macs[MacCount.HardwareMacs] += num_mac_ops
-
-            blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], ofm_block.depth)
-
-            non_zero_fraction = 1.0
-            if ifm_tensor.values is not None:
-                nz_vector = np.amax(ifm_tensor.values != 0, axis=0)  # max across batch axis
-                non_zero_fraction = np.average(nz_vector)
-
-            replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth()
-            replacement_read_bws[weight_tensor] = weight_tensor.bandwidth() * non_zero_fraction
-            ifm_read_multiple = 1
-            weight_read_multiple = non_zero_fraction
         elif npu_block_type == NpuBlockType.ElementWise:
             # Work out how many elements we have and calculate performance.
             cycles[PassCycles.Npu] = estimate_output_cycles(
@@ -729,8 +577,9 @@
             if rewrite_op == SchedulerRewrite.Nop:
                 pass  # these are fine, no bandwidth changes
             elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,):
+                bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += replacement_read_bws[tens]
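+                # Record the raw replacement bandwidth; the efficiency-scaled value is derived below.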
                 if tens.purpose == TensorPurpose.FeatureMap:
-                    bw = estimate_memory_bandwidth(
+                    scaled_bw = estimate_memory_transfer_efficiency(
                         arch,
                         arch.fast_storage_mem_area,
                         BandwidthDirection.Read,
@@ -739,22 +588,27 @@
                         replacement_read_bws[tens],
                     )
                 else:
-                    bw = replacement_read_bws[tens]
-                bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += bw
+                    scaled_bw = replacement_read_bws[tens]
+                scaled_bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += scaled_bw
                 replacement_read_bws[tens] = 0
 
     for tens in ps.outputs:
         if force_outputs_to_fast_storage:
-            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_bandwidth(
+            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
+            scaled_bws[arch.fast_storage_mem_area][tens.purpose][
+                BandwidthDirection.Write
+            ] += estimate_memory_transfer_efficiency(
                 arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block
             )
         else:
-            bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_bandwidth(
+            bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
+            scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_transfer_efficiency(
                 arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block
             )
 
     for tens in ps.intermediates:
         bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
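+        # Intermediate traffic gets no efficiency scaling, so raw and scaled values match.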
+        scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
 
         if tens in replacement_read_bws:
             bw = replacement_read_bws[tens]
@@ -762,16 +616,23 @@
             bw = tens.bandwidth()
 
         bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
+        scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
 
     for tens in ps.inputs:
-        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_bandwidth(
-            arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, replacement_read_bws.get(tens)
+        if tens in replacement_read_bws:
+            bw = replacement_read_bws[tens]
+        else:
+            bw = tens.bandwidth()
+
+        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
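+        # Scale the read bandwidth by the estimated memory transfer efficiency.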
+        scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_transfer_efficiency(
+            arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, bw
         )
 
     # quickly build access counts for the current pass only, even though these aren't the final numbers
-    update_summary_cycles(arch, bws, cycles)
+    update_summary_cycles(arch, scaled_bws, cycles)
 
-    return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple
+    return bws, macs, cycles, ifm_read_multiple, weight_read_multiple
 
 
 def update_summary_cycles(arch, bws, cycles):
@@ -794,15 +655,14 @@
 
 def performance_for_cascaded_pass(arch, cps):
     total_bws = make_bandwidth_array()
-    total_macs = make_macs_array()
+    total_macs = 0
     total_cycles = make_cycles_array()
 
     for ps in cps.passes:
-        bws, macs, cycles, blocks, _, _ = performance_metrics_for_pass(arch, ps)
+        bws, macs, cycles, _, _ = performance_metrics_for_pass(arch, ps)
         ps.bandwidths = bws
         ps.macs = macs
         ps.cycles = cycles
-        ps.n_blocks = blocks
         total_bws += bws
         total_macs += macs
         total_cycles += cycles
@@ -816,7 +676,7 @@
 
 def calc_performance_for_network(nng, arch):
     total_bws = make_bandwidth_array()
-    total_macs = np.zeros(MacCount.Size)
+    total_macs = 0
     total_cycles = np.zeros(PassCycles.Size)
 
     for sg in nng.subgraphs: