blob: 29e0df9a24360175e717caf22bcd84f206548cc7 [file] [log] [blame]
Tim Hall79d07d22020-04-27 18:20:16 +01001# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Tim Hall79d07d22020-04-27 18:20:16 +010016# Description:
17# NPU performance estimation functions to estimate performance of a Pass and CascadedPass. Uses a model that takes the
18# maximum of the 'cycles required for bandwidth' and 'cycles required for computing'.
19#
20# Called during scheduling to evaluate different proposals, as well as post-scheduling to provide a final performance
21# estimate.
Diqing Zhonge168b962020-11-05 17:18:47 +010022from enum import auto
23from enum import IntEnum
Diego Russoea6111a2020-04-14 18:41:58 +010024
Tim Hall79d07d22020-04-27 18:20:16 +010025import numpy as np
Diego Russoea6111a2020-04-14 18:41:58 +010026
27from . import numeric_util
Diqing Zhong09387e22020-09-28 18:46:22 +020028from .architecture_features import Accelerator
Diego Russoe8a10452020-04-21 17:39:10 +010029from .architecture_features import Block
Diqing Zhonge8887a32020-09-24 09:53:48 +020030from .data_type import DataType
Diego Russoe8a10452020-04-21 17:39:10 +010031from .nn_graph import PassPlacement
32from .nn_graph import SchedulerRewrite
Diego Russoea6111a2020-04-14 18:41:58 +010033from .operation import NpuBlockType
Diqing Zhonge8887a32020-09-24 09:53:48 +020034from .operation import Op
Diqing Zhong09387e22020-09-28 18:46:22 +020035from .shared_buffer_allocation import is_acc_40bits_used
Diego Russoe8a10452020-04-21 17:39:10 +010036from .tensor import MemArea
37from .tensor import shape_num_elements
38from .tensor import TensorBlockTraversal
Diqing Zhonge168b962020-11-05 17:18:47 +010039from .tensor import TensorFormat
Diego Russoe8a10452020-04-21 17:39:10 +010040from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010041
42
def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_config_ps2):
    """Return the ``[height, width]`` of the rolling buffer that feeds ``ps2`` from ``ps1``.

    The consumer's OFM block is built from ``block_config_ps2`` and converted to the IFM block
    it needs. That height is then padded with one extra producer block row and rounded up to
    a whole number of producer block rows (worst case). ``ps1`` itself is not inspected.
    """
    consumer_ofm_block = Block(block_config_ps2[-3], block_config_ps2[-4], block_config_ps2[-1])
    consumer_kernel = ps2.primary_op.kernel

    if ps2.npu_block_type not in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
        consumer_ifm_depth = block_config_ps2[-1]
    else:
        # Weight-based ops derive the IFM block depth from the IFM channel count and bit width.
        consumer_op = ps2.primary_op
        consumer_ifm_depth = arch.calc_ifm_block_depth(
            consumer_op.ifm.shape[-1], consumer_op.ifm.dtype.size_in_bits()
        )

    consumer_ifm_block = arch.get_ifm_block_size(
        consumer_ifm_depth, consumer_ofm_block, consumer_kernel, arch.ofm_block_max
    )

    # The performed height calculation is for worst case
    producer_rows = block_config_ps1[0]
    buffer_height = numeric_util.round_up(consumer_ifm_block.height + producer_rows, producer_rows)
    return [buffer_height, consumer_ifm_block.width]
Tim Hall79d07d22020-04-27 18:20:16 +010059
60
class PassCycles(IntEnum):
    """Categories of cycle counts; each value is an index into a cycles array."""

    Npu = 0
    Cpu = 1
    SramAccess = 2
    DramAccess = 3
    OnChipFlashAccess = 4
    OffChipFlashAccess = 5
    Total = 6
    Size = 7  # number of categories, used to size arrays

    def display_name(self):
        """Human-readable label for this category."""
        labels = (
            "NPU",
            "CPU",
            "SRAM Access",
            "DRAM Access",
            "On-chip Flash Access",
            "Off-chip Flash Access",
            "Total",
            "Size",
        )
        return labels[int(self)]

    def identifier_name(self):
        """snake_case identifier for this category."""
        identifiers = (
            "npu",
            "cpu",
            "sram_access",
            "dram_access",
            "on_chip_flash_access",
            "off_chip_flash_access",
            "total",
            "size",
        )
        return identifiers[int(self)]

    @staticmethod
    def all():
        """Every real category in definition order, i.e. everything except ``Size``."""
        return tuple(member for member in PassCycles if member is not PassCycles.Size)
106
107
class MacCount(IntEnum):
    """Categories of multiply-accumulate counts; each value indexes a MACs array."""

    NeuralNetworkMacs = 0
    HardwareMacs = 1
    Size = 2  # number of categories, used to size arrays

    def display_name(self):
        """Human-readable label for this category."""
        return {0: "Neural Network Macs", 1: "Hardware Macs", 2: "Size"}[self.value]

    def identifier_name(self):
        """snake_case identifier for this category."""
        return {0: "nn_macs", 1: "hardware_macs", 2: "size"}[self.value]

    @staticmethod
    def all():
        """Every real category (everything except ``Size``)."""
        return tuple(member for member in MacCount if member is not MacCount.Size)
122
123
class BandwidthDirection(IntEnum):
    """Direction of a memory transfer; each value indexes the last axis of a bandwidth array."""

    Read = 0
    Write = 1
    Size = 2  # number of directions, used to size arrays

    def display_name(self):
        """The member name, e.g. ``"Read"``."""
        return self.name

    def identifier_name(self):
        """The member name in lower case, e.g. ``"read"``."""
        return self.display_name().lower()

    @staticmethod
    def all():
        """The real directions (everything except ``Size``)."""
        return BandwidthDirection.Read, BandwidthDirection.Write
138
139
def make_bandwidth_array():
    """Return a zeroed float array indexed as ``[mem_area][tensor_purpose][direction]``."""
    dims = (MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size)
    return np.zeros(dims)
142
143
def make_macs_array():
    """Return a zeroed integer array with one slot per MacCount category.

    The dtype is the builtin ``int``. ``np.int`` was only a deprecated alias of it
    (deprecated in NumPy 1.20, removed in 1.24, where it raises AttributeError). Using
    ``int`` keeps the same dtype and works on current NumPy releases.
    """
    return np.zeros(MacCount.Size, int)
146
147
def make_cycles_array():
    """Return a zeroed float array with one slot per PassCycles category."""
    return np.zeros(shape=PassCycles.Size)
150
151
def make_metrics_arrays():
    """Return a fresh ``(bandwidth, macs, cycles)`` triple of zeroed metric arrays."""
    bandwidth = make_bandwidth_array()
    mac_counts = make_macs_array()
    cycle_counts = make_cycles_array()
    return bandwidth, mac_counts, cycle_counts
154
155
def get_n_blocks_and_area(
    ifm_brick_size, ifm_height_width, orig_skirt, clamped_skirt, block_config, min_block_size, strides
):
    """Count the blocks covering the IFM and the total IFM area they read.

    The IFM (minus the original skirt) is tiled into full blocks plus partial edge blocks.
    The partial edges are sized to a multiple of ``min_block_size``. Four block shapes arise:

        0000000001
        0000000001
        0000000001
        2222222223

    (0 = full blocks, 1 = right edge, 2 = bottom edge, 3 = corner). Each block's read area
    includes the clamped skirt on both sides, rounded up to the IFM brick size, so reads into
    edge padding are counted.

    Returns:
        ``(total_blocks, total_area, block_setup)``, where ``block_setup`` is a tuple of
        ``(block_count, block_size)`` pairs, one per shape above.
    """
    full_counts = []
    tail_sizes = []
    for axis in (0, 1):
        stride = strides[axis + 1]
        ifm_step = block_config[axis] * stride
        inner_extent = ifm_height_width[axis] - orig_skirt[axis] - orig_skirt[axis + 2]
        count = inner_extent // ifm_step
        leftover = inner_extent - count * ifm_step
        full_counts.append(count)
        # Convert the leftover IFM extent back to OFM elements, then round up to the minimum block.
        tail_sizes.append(numeric_util.round_up((leftover - 1) // stride + 1, min_block_size[axis]))

    rows_full, cols_full = full_counts
    block_setup = (
        (rows_full * cols_full, block_config),
        (cols_full, (tail_sizes[0], block_config[1])),
        (rows_full, (block_config[0], tail_sizes[1])),
        (1, tail_sizes),
    )

    total_blocks = 0
    total_area = 0
    for count, size in block_setup:
        if size[0] == 0 or size[1] == 0:
            continue  # this edge is empty
        padded = [
            numeric_util.round_up(clamped_skirt[axis], ifm_brick_size[axis + 1])
            + size[axis] * strides[axis + 1]
            + numeric_util.round_up(clamped_skirt[axis + 2], ifm_brick_size[axis + 1])
            for axis in (0, 1)
        ]
        assert count >= 0
        total_blocks += count
        total_area += count * padded[0] * padded[1]
    assert total_blocks >= 1
    return total_blocks, total_area, block_setup
209
210
def get_ifm_block_depth(npu_block_type, ifm_depth, ifm_elemwidth, block_traversal, ofm_blk_depth):
    """Return the IFM block depth, never larger than ``ifm_depth``.

    ConvolutionMxN and ReduceSum use a depth chosen from the element width and traversal:
    16 for 16-bit or part-kernel-first, 32 for 8-bit, otherwise 8. All other block types
    follow the OFM block depth.
    """
    if npu_block_type not in (NpuBlockType.ConvolutionMxN, NpuBlockType.ReduceSum):
        return min(ifm_depth, ofm_blk_depth)

    if ifm_elemwidth == 16 or block_traversal == TensorBlockTraversal.PartKernelFirst:
        chosen_depth = 16
    elif ifm_elemwidth == 8:
        chosen_depth = 32
    else:
        chosen_depth = 8
    return min(ifm_depth, chosen_depth)
223
224
def estimate_output_cycles(
    arch, npu_block_type, primary_op, num_elems, ifm_tensor, ofm_tensor, ifm2_tensor, use_acc_40bits=False
):
    """Estimate output-stage cycles for ``num_elems`` output elements.

    An operation-dependent index selects a cost from ``arch.output_cycles_per_elem``, and the
    fused activation selects a cost from ``arch.activation_cycles_per_elem``. The larger of
    the two per-element costs is multiplied by ``num_elems``.
    """
    activation = primary_op.activation
    faf = activation.op_type if activation is not None else None
    op_type = primary_op.type

    if npu_block_type == NpuBlockType.ElementWise and ifm_tensor.dtype == DataType.int32:
        # 32-bit elementwise: unary (no second input) -> 0, binary -> 1
        output_perf_index = 0 if ifm2_tensor is None else 1
    elif op_type == Op.Mul and ofm_tensor.dtype == DataType.int32:
        output_perf_index = 2
    elif op_type == Op.Mul or (
        npu_block_type
        in (
            NpuBlockType.ConvolutionMxN,
            NpuBlockType.ConvolutionDepthWise,
            NpuBlockType.Pooling,
            NpuBlockType.ReduceSum,
            NpuBlockType.VectorProduct,
        )
        and use_acc_40bits
    ):
        output_perf_index = 3
    elif op_type in (Op.Add, Op.Sub):
        input_scale = ifm_tensor.quantization.scale_f32
        input2_scale = ifm2_tensor.quantization.scale_f32
        output_scale = ofm_tensor.quantization.scale_f32
        if "resizebilinear" in primary_op.attrs:
            output_scale = input2_scale
        is_simple = None in (input_scale, input2_scale, output_scale) or input_scale == input2_scale
        # Simple Add/Sub -> 4, Advanced Add/Sub -> 5
        output_perf_index = 4 if is_simple else 5
    elif op_type.is_maxpool_op():
        output_perf_index = 6
    else:
        output_perf_index = 7

    if faf in (Op.Sigmoid, Op.Tanh, Op.LUT):
        activation_perf_index = 0
    elif faf in (Op.Relu, Op.Relu6, Op.ReluN1To1):
        activation_perf_index = 1
    else:
        activation_perf_index = 2

    output_cost = arch.output_cycles_per_elem[output_perf_index]
    activation_cost = arch.activation_cycles_per_elem[activation_perf_index]
    return num_elems * max(output_cost, activation_cost)
281
282
def estimate_conv_pooling_cycles(
    arch,
    npu_block_type,
    primary_op,
    block_config: Block,
    block_traversal,
    kernel_dims,
    ifm_tensor,
    ofm_tensor,
    scale_tensor=None,
):
    """Estimate NPU cycles for a convolution, pooling, reduce-sum or vector-product pass.

    The DPU (MAC) stage and the output stage are costed per OFM block. The slower of the two
    is charged once per OFM block, plus one instance of the faster stage (presumably
    pipeline fill/drain).

    Args:
        arch: Architecture features (micro-block size, sub-kernel limits, core count, ...).
        npu_block_type: NpuBlockType of the operation.
        primary_op: Primary operation of the pass; used for the output-stage estimate.
        block_config: OFM block size (Block). It is not modified. When the single-row
            micro-block optimisation applies, a local copy with height 1 is used.
        block_traversal: Weight traversal order (TensorBlockTraversal).
        kernel_dims: Kernel dimensions as [height, width].
        ifm_tensor: Input feature map tensor.
        ofm_tensor: Output feature map tensor.
        scale_tensor: Optional scale/bias tensor. When given, a latency that depends on its
            memory area may bound the output stage.

    Returns:
        Estimated total cycles (may be a float because of the division by ``arch.ncores``).
    """
    ofm_ublock = Block(arch.config.ofm_ublock.width, arch.config.ofm_ublock.height, arch.config.ofm_ublock.depth)
    ifm_tens_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
    ofm_tens_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1)

    if (
        arch.config.ofm_ublock.height == 2
        and npu_block_type
        in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.VectorProduct)
        and ofm_tens_shape[1] == 1
        # Optimisation only applies for even width tensors
        and ofm_tens_shape[2] % 2 == 0
        and kernel_dims[0] == 1
    ):
        # Single-row OFM: use a 4-wide, 1-high micro block.
        ofm_ublock.width = 4
        ofm_ublock.height = 1
        # Use a copy with height 1 rather than mutating the caller's Block in place.
        block_config = Block(block_config.width, 1, block_config.depth)

    num_ublk_xy = numeric_util.round_up_divide(block_config.width, ofm_ublock.width) * (
        block_config.height // ofm_ublock.height
    )
    num_ublk_z = block_config.depth // ofm_ublock.depth

    num_elems_blk = block_config.width * block_config.height * block_config.depth

    use_acc_40bits = is_acc_40bits_used(npu_block_type, ifm_tensor, ofm_tensor)

    # Split the kernel into sub-kernels no larger than the hardware limits.
    sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
    n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
    n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
    sub_kernel_x = [
        min((kernel_dims[1] - i * sub_kernel_limits[1]), sub_kernel_limits[1]) for i in range(n_sub_kernels_x)
    ]
    sub_kernel_y = [
        min((kernel_dims[0] - i * sub_kernel_limits[0]), sub_kernel_limits[0]) for i in range(n_sub_kernels_y)
    ]
    sub_kernel_size = (x * y for y in sub_kernel_y for x in sub_kernel_x)

    ifm_blk_depth = get_ifm_block_depth(
        npu_block_type, ifm_tens_shape[3], ifm_tensor.dtype.size_in_bits(), block_traversal, block_config.depth
    )
    cycles_dpu_blk = 0
    cycles_wb = 32 * ofm_ublock.depth // 8

    for num_kernel_elems in sub_kernel_size:
        if npu_block_type == NpuBlockType.Pooling:
            cycles = max(4, num_kernel_elems) * num_ublk_xy * num_ublk_z
            if ifm_tensor.dtype.size_in_bits() == 16 and arch.accelerator_config != Accelerator.Ethos_U55_32:
                cycles *= 2
        elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
            cycles = 4 * num_ublk_xy
            if ifm_tensor.dtype.size_in_bits() == 16:
                cycles *= 2
            cycles = max(cycles_wb, cycles) * numeric_util.round_up_divide(num_kernel_elems, 4) * num_ublk_z
        elif (
            (npu_block_type == NpuBlockType.ConvolutionMxN and block_traversal != TensorBlockTraversal.PartKernelFirst)
            or npu_block_type == NpuBlockType.VectorProduct
            or npu_block_type == NpuBlockType.ReduceSum
        ):
            cycles = (
                max(cycles_wb, 4 * num_ublk_xy)
                * num_kernel_elems
                * num_ublk_z
                * numeric_util.round_up_divide(ifm_tens_shape[3], ifm_blk_depth)
            )
        else:
            assert block_traversal == TensorBlockTraversal.PartKernelFirst
            divider = 2 if ifm_tensor.dtype.size_in_bits() == 16 else 4
            cycles = max(cycles_wb, 4 * num_ublk_xy) * (
                numeric_util.round_up_divide(num_kernel_elems, divider)
                * numeric_util.round_up_divide(ifm_blk_depth, 8)
                * num_ublk_z
                * numeric_util.round_up_divide(ifm_tens_shape[3], ifm_blk_depth)
            )
        cycles_dpu_blk += cycles

    cycles_dpu_blk /= arch.ncores

    num_ofm_blk = (
        numeric_util.round_up_divide(ofm_tens_shape[1], block_config.height)
        * numeric_util.round_up_divide(ofm_tens_shape[2], block_config.width)
        * numeric_util.round_up_divide(ofm_tens_shape[3], block_config.depth)
    )

    cycles_output_blk = estimate_output_cycles(
        arch, npu_block_type, primary_op, num_elems_blk, ifm_tensor, ofm_tensor, None, use_acc_40bits
    )

    if scale_tensor is not None:
        # Latency of fetching scales/biases depends on where the scale tensor lives.
        if scale_tensor.mem_area is MemArea.Sram:
            latency = 32
        elif scale_tensor.mem_area is MemArea.Dram:
            latency = 500
        else:
            latency = 64
        cycles_bias_blk = 10 * min(block_config.depth, ofm_tens_shape[3]) * latency / 256
        cycles_output_blk = max(cycles_output_blk, cycles_bias_blk)

    # Slower stage dominates each block; the faster stage is added once.
    if cycles_dpu_blk > cycles_output_blk:
        total_cycles = cycles_dpu_blk * num_ofm_blk + cycles_output_blk
    else:
        total_cycles = cycles_output_blk * num_ofm_blk + cycles_dpu_blk

    return total_cycles
399
400
def estimate_memory_bandwidth(arch, mem_area, direction, tensor, block_size: Block, replace_bw=None):
    """Estimate the effective bandwidth of moving ``tensor`` in blocks of ``block_size``.

    The raw bandwidth is ``replace_bw`` when given, else the tensor's ``bandwidth()``. It is
    scaled by ``max_burst_len / burst_len``, so short bursts cost proportionally more. Burst
    lengths are in bytes. Tensors that are neither NHWC nor NHCWB16 get the raw bandwidth
    unscaled.
    """
    if tensor.format not in (TensorFormat.NHWC, TensorFormat.NHCWB16):
        return tensor.bandwidth() if replace_bw is None else replace_bw

    # Longest burst considered: 32 bytes for SRAM, 128 bytes for any other memory area.
    max_burst_len = 32 if mem_area == MemArea.Sram else 128
    elem_size = tensor.dtype.size_in_bytes()
    is_read = direction == BandwidthDirection.Read

    # Work on a copy so the caller's tensor keeps its format.
    candidate = tensor.clone()
    if not candidate.avoid_NHCWB16:
        candidate.set_format(TensorFormat.NHCWB16, arch)

    full_block_burst = elem_size * block_size.depth * block_size.width

    if candidate.format == TensorFormat.NHCWB16:
        if candidate.get_strides()[1] == block_size.depth:
            burst_len = full_block_burst
        else:
            burst_len = 16 * elem_size * block_size.width
            if not is_read:
                burst_len *= arch.ncores  # write bursts scale with the core count
    else:
        assert candidate.format == TensorFormat.NHWC
        if is_read:
            if candidate.get_strides()[3] == block_size.depth:
                burst_len = full_block_burst
            else:
                burst_len = elem_size * block_size.depth
        elif block_size.depth <= 16 and candidate.get_strides()[3] == block_size.depth:
            burst_len = full_block_burst
        else:
            burst_len = min(64, 16 * elem_size * arch.ncores, block_size.depth * elem_size)

    burst_len = min(max_burst_len, burst_len)
    raw_bw = candidate.bandwidth() if replace_bw is None else replace_bw
    return raw_bw * (max_burst_len / burst_len)
439
440
def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, force_outputs_to_fast_storage=False):
    """Estimate bandwidth, MAC counts and cycle counts for a single pass.

    Args:
        arch: Architecture features (block sizes, bandwidths, cycle tables, ...).
        ps: The pass to evaluate.
        block_config: Block configuration to evaluate; defaults to ``ps.block_config``. The
            Block construction below reads it as [height, width, <index 2>, ofm depth].
            Index 2 is used as the number of input channels processed at a time.
        rewrite_list: Scheduler rewrites as 6-tuples whose 1st element is the rewrite op,
            2nd the tensor and last the pass it applies to. Only entries for ``ps`` are
            applied. ``None`` (the default) means no rewrites.
        force_outputs_to_fast_storage: If True, output write bandwidth is charged to
            ``arch.fast_storage_mem_area`` instead of each output tensor's own memory area.

    Returns:
        ``(bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple)``.
    """
    # A shared mutable default list would be evaluated once and reused across calls.
    if rewrite_list is None:
        rewrite_list = []
    if block_config is None:
        block_config = ps.block_config
    bws = make_bandwidth_array()
    macs = make_macs_array()
    cycles = make_cycles_array()
    blocks = 0
    ifm_read_multiple = 1
    weight_read_multiple = 0

    if ps.placement in (PassPlacement.MemoryOnly, PassPlacement.StartupInit):
        return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass

    min_block_size = arch.min_block_sizes[ps.npu_block_type]

    skirt = (0, 0, 0, 0)
    explicit_padding = (0, 0, 0, 0)
    primary_op = ps.primary_op
    replacement_read_bws = {}
    ofm_block = Block(block_config[1], block_config[0], block_config[3])
    ifm_block = Block(block_config[1], block_config[0], block_config[3])

    if ps.placement == PassPlacement.Cpu:
        cycles[PassCycles.Cpu] = arch.cpu_cycle_estimate(ps.ops[0])
    elif primary_op:
        skirt = primary_op.attrs.get("skirt", skirt)
        explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
        assert primary_op.type.npu_block_type == ps.npu_block_type
        npu_block_type = primary_op.type.npu_block_type
        block_traversal = TensorBlockTraversal.Default

        ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
        ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)

        if npu_block_type in (
            NpuBlockType.ConvolutionMxN,
            NpuBlockType.ConvolutionDepthWise,
            NpuBlockType.Pooling,
            NpuBlockType.ReduceSum,
        ):
            # extend the ifm to full dimension
            ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1))
            ifm_tensor_bandwidth_shape = numeric_util.full_shape(4, ifm_tensor.bandwidth_shape, 1)

            batch_size = ifm_tensor_shape[0]
            ifm_depth = ifm_tensor_bandwidth_shape[3]

            # add in padding
            ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
            ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3]  # width += left and right

            strides = primary_op.attrs["strides"]
            if npu_block_type != NpuBlockType.Pooling:
                if npu_block_type == NpuBlockType.ReduceSum:
                    # ReduceSum has no weight tensor: model it as 1x1 weights with no weight traffic.
                    block_traversal = TensorBlockTraversal.DepthFirst
                    weight_tensor_shape = [1, 1, ifm_tensor.shape[3], ofm_tensor.shape[3]]
                    weight_tensor_bandwidth_shape = [0] * 4
                    weight_tensor_element_size = 0
                    weight_tensor_bandwidth_compression_scale = 0.0
                else:
                    block_traversal = weight_tensor.block_traversal
                    weight_tensor_shape = weight_tensor.shape
                    weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
                    weight_tensor_element_size = weight_tensor.element_size()
                    weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
                nn_ops = (
                    int(ofm_tensor.shape[0])
                    * int(ofm_tensor.shape[1])
                    * int(ofm_tensor.shape[2])
                    * int(weight_tensor_shape[0])
                    * int(weight_tensor_shape[1])
                    * int(weight_tensor_shape[2])
                    * int(weight_tensor_shape[3])
                )
            else:
                weight_tensor_shape = [
                    primary_op.attrs["ksize"][1],
                    primary_op.attrs["ksize"][2],
                    1,
                    ifm_tensor_shape[3],
                ]
                weight_tensor_bandwidth_shape = weight_tensor_shape
                weight_tensor_element_size = 0
                weight_tensor_bandwidth_compression_scale = 0.0
                nn_ops = 0  # pooling doesn't count as NN ops

            kernel_dims = weight_tensor_shape[:2]

            sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
            # count the sub kernels; the IFM block needs to be refetched for each of them
            n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
            n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
            n_sub_kernels = n_sub_kernels_y * n_sub_kernels_x

            clamped_skirt = list(skirt)
            clamped_skirt[2] = min(clamped_skirt[2], sub_kernel_limits[0] - 1 - clamped_skirt[0])
            clamped_skirt[3] = min(clamped_skirt[3], sub_kernel_limits[1] - 1 - clamped_skirt[1])
            n_blocks, area, block_setup = get_n_blocks_and_area(
                ifm_tensor_brick_size,
                ifm_tensor_shape[1:3],
                skirt,
                clamped_skirt,
                block_config,
                min_block_size,
                strides,
            )

            blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], ofm_block.depth)

            n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], ofm_block.depth)
            if npu_block_type in (
                NpuBlockType.ConvolutionDepthWise,
                NpuBlockType.Pooling,
                # ReduceSum's weight bandwidth shape is all zeros, which would otherwise give
                # zero weight stages and hence zero estimated IFM read bandwidth.
                NpuBlockType.ReduceSum,
            ):
                n_weight_stages = 1  # force to no reread

            ifm_tensor_bw = (
                n_sub_kernels
                * batch_size
                * area
                * ifm_depth
                * n_weight_stages
                * ifm_tensor.element_size()
                * ifm_tensor.bandwidth_compression_scale
            )
            replacement_read_bws[ifm_tensor] = ifm_tensor_bw
            ifm_read_multiple = n_weight_stages

            replacement_read_bws[weight_tensor] = (
                batch_size
                * shape_num_elements(weight_tensor_bandwidth_shape)
                * weight_tensor_element_size
                * weight_tensor_bandwidth_compression_scale
                * n_blocks
            )  # read once per block and batch
            weight_read_multiple = n_blocks

            n_kernel_xy = kernel_dims[0] * kernel_dims[1]
            n_input_channels_at_a_time = block_config[2]

            if npu_block_type == NpuBlockType.Pooling or block_traversal in (
                TensorBlockTraversal.PartKernelFirst,
                TensorBlockTraversal.DepthWise,
            ):
                n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
                n_kernel_xy = max(
                    n_kernel_xy, 4
                )  # need at least 4, as this is the minimum duty cycle for secondary accumulator writes
                if weight_tensor is not None:
                    n_kernel_xy = numeric_util.round_up(n_kernel_xy, 4)  # weights need to be read in blocks of 4

            num_mac_ops = 0
            for n_blocks_for_size, block_size in block_setup:
                num_mac_ops += (
                    batch_size
                    * n_blocks_for_size
                    * block_size[0]
                    * block_size[1]
                    * numeric_util.round_up(weight_tensor_shape[2], n_input_channels_at_a_time)
                    * numeric_util.round_up(weight_tensor_shape[3], ofm_block.depth)
                    * n_kernel_xy
                )

            macs[MacCount.NeuralNetworkMacs] += nn_ops
            macs[MacCount.HardwareMacs] += num_mac_ops
            cycles[PassCycles.Npu] = estimate_conv_pooling_cycles(
                arch,
                npu_block_type,
                primary_op,
                ofm_block,
                block_traversal,
                kernel_dims,
                ifm_tensor,
                ofm_tensor,
                ps.scale_tensor,
            )
        elif npu_block_type == NpuBlockType.VectorProduct:
            nn_macs = (
                ifm_tensor.shape[0]
                * numeric_util.round_up(weight_tensor.shape[-2], block_config[2])
                * numeric_util.round_up(weight_tensor.shape[-1], block_config[3])
            )
            num_mac_ops = nn_macs
            block_traversal = weight_tensor.block_traversal

            cycles[PassCycles.Npu] = estimate_conv_pooling_cycles(
                arch, npu_block_type, primary_op, ofm_block, block_traversal, [1, 1], ifm_tensor, ofm_tensor,
            )
            macs[MacCount.NeuralNetworkMacs] += nn_macs
            macs[MacCount.HardwareMacs] += num_mac_ops

            blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], ofm_block.depth)

            # Scale weight traffic by the fraction of IFM positions that are non-zero in any batch.
            non_zero_fraction = 1.0
            if ifm_tensor.values is not None:
                nz_vector = np.amax(ifm_tensor.values != 0, axis=0)  # max across batch axis
                non_zero_fraction = np.average(nz_vector)

            replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth()
            replacement_read_bws[weight_tensor] = weight_tensor.bandwidth() * non_zero_fraction
            ifm_read_multiple = 1
            weight_read_multiple = non_zero_fraction
        elif npu_block_type == NpuBlockType.ElementWise:
            # Work out how many elements we have and calculate performance.
            cycles[PassCycles.Npu] = estimate_output_cycles(
                arch, npu_block_type, primary_op, ofm_tensor.elements(), ps.ifm_tensor, ps.ofm_tensor, ps.ifm2_tensor
            )

        ifm_block_depth = get_ifm_block_depth(
            npu_block_type, ifm_tensor_shape[3], ifm_tensor.dtype.size_in_bits(), block_traversal, ofm_block.depth
        )
        ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, primary_op.kernel)

        prev_npu_pass = next((npu_ps for npu_ps in ps.dag_predecessors if npu_ps.placement is PassPlacement.Npu), None)
        if prev_npu_pass is None:
            # cycles for DMA ops in first pass
            dma_ops = (op for op in ps.ops if op.type == Op.DMA)
            for dma_op in dma_ops:
                mem_area = dma_op.attrs["source"]
                for tens in dma_op.inputs:
                    cycles[PassCycles.Npu] += tens.storage_size() / arch.memory_bandwidths_per_cycle[mem_area]

    # apply the desired rewrites
    for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list:
        if ps != ps_to_rewrite:
            continue
        if rewrite_op == SchedulerRewrite.Nop:
            pass  # these are fine, no bandwidth changes
        elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,):
            if tens.purpose == TensorPurpose.FeatureMap:
                bw = estimate_memory_bandwidth(
                    arch,
                    arch.fast_storage_mem_area,
                    BandwidthDirection.Read,
                    tens,
                    ifm_block,
                    replacement_read_bws[tens],
                )
            else:
                bw = replacement_read_bws[tens]
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += bw
            # The read is now charged to fast storage; do not count it again below.
            replacement_read_bws[tens] = 0

    for tens in ps.outputs:
        if force_outputs_to_fast_storage:
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_bandwidth(
                arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block
            )
        else:
            bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_bandwidth(
                arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block
            )

    for tens in ps.intermediates:
        bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

    for tens in ps.inputs:
        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_bandwidth(
            arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, replacement_read_bws.get(tens)
        )

    # quick build access counts for only current pass, even though these aren't the final numbers
    update_summary_cycles(arch, bws, cycles)

    return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple
711
712
def update_summary_cycles(arch, bws, cycles):
    """Fill in memory-access cycle counts from ``bws`` and set the Total, in place.

    Each memory area's access cycles are its total bandwidth divided by that area's
    bandwidth per cycle. ``Total`` becomes the maximum over every category before it.
    ``cycles`` is modified and also returned.
    """
    access_categories = (
        (PassCycles.SramAccess, MemArea.Sram),
        (PassCycles.DramAccess, MemArea.Dram),
        (PassCycles.OnChipFlashAccess, MemArea.OnChipFlash),
        (PassCycles.OffChipFlashAccess, MemArea.OffChipFlash),
    )
    for cycle_kind, area in access_categories:
        cycles[cycle_kind] = np.sum(bws[area]) / arch.memory_bandwidths_per_cycle[area]

    cycles[PassCycles.Total] = np.max(cycles[: PassCycles.Total])
    return cycles
725
726
def collate_stats_for_cascaded_pass(arch, bws, macs, cycles):
    """Combine summed per-pass metrics into cascaded-pass statistics.

    Currently a pass-through: the three arrays are returned unchanged and ``arch`` is unused.
    """
    collated = (bws, macs, cycles)
    return collated
729
730
def performance_for_cascaded_pass(arch, cps):
    """Evaluate every pass in ``cps`` and store per-pass and cascaded-pass totals.

    Each pass gets ``bandwidths``, ``macs``, ``cycles`` and ``n_blocks`` attributes. ``cps``
    gets ``bandwidths``, ``macs`` and ``cycles``. Returns ``(bws, macs, cycles)`` for ``cps``.
    """
    total_bws, total_macs, total_cycles = make_metrics_arrays()

    for ps in cps.passes:
        pass_bws, pass_macs, pass_cycles, pass_blocks, _, _ = performance_metrics_for_pass(arch, ps)
        ps.bandwidths, ps.macs, ps.cycles, ps.n_blocks = pass_bws, pass_macs, pass_cycles, pass_blocks
        total_bws += pass_bws
        total_macs += pass_macs
        total_cycles += pass_cycles

    cps_bws, cps_macs, cps_cycles = collate_stats_for_cascaded_pass(arch, total_bws, total_macs, total_cycles)
    cps.bandwidths, cps.macs, cps.cycles = cps_bws, cps_macs, cps_cycles
    return cps_bws, cps_macs, cps_cycles
751
752
def calc_performance_for_network(nng, arch):
    """Sum cascaded-pass metrics over all subgraphs and store them on ``nng``.

    Sets ``nng.bandwidths``, ``nng.macs`` and ``nng.cycles``. Returns nothing.
    """
    total_bws = make_bandwidth_array()
    total_macs = np.zeros(MacCount.Size)
    total_cycles = np.zeros(PassCycles.Size)

    every_cps = (cps for sg in nng.subgraphs for cps in sg.cascaded_passes)
    for cps in every_cps:
        cps_bws, cps_macs, cps_cycles = performance_for_cascaded_pass(arch, cps)
        total_bws += cps_bws
        total_macs += cps_macs
        total_cycles += cps_cycles

    nng.bandwidths, nng.macs, nng.cycles = total_bws, total_macs, total_cycles
767 nng.cycles = total_cycles