# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# NPU performance estimation functions to estimate performance of a Pass and CascadedPass. Uses a model that takes the
# maximum of the 'cycles required for bandwidth' and 'cycles required for computing'.
#
# Called during scheduling to evaluate different proposals, as well as post-scheduling to provide a final performance
# estimate.
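#
# Illustrative example (hypothetical numbers, not tied to any particular configuration): a pass that
# needs 1,048,576 hardware MACs on an engine doing 256 MACs/cycle costs ~4096 compute cycles; if it
# also moves 2 MiB of SRAM traffic at 16 bytes/cycle (~131072 access cycles), the estimate is
# max(4096, 131072) = 131072 cycles, i.e. the pass is bandwidth bound.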

import enum
from . import numeric_util
import numpy as np
from .tensor import TensorPurpose, MemArea, TensorFormat, shape_num_elements, Tensor, TensorBlockTraversal
from .operation import Operation
from .data_type import DataType, BaseType
from .nn_graph import PassPlacement, NpuBlockType, SchedulerRewrite, Pass
from .architecture_features import Block, Kernel


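# Derive rolling buffer dimensions (height, width) for cascading ps1 into ps2: ps2's kernel, strides
# and dilation are used to work out the IFM block that ps1 needs to keep resident.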
def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_config_ps2):
    ps2_strides = (1, 1, 1, 1)
    ps2_dilation = (1, 1, 1, 1)
    for op in ps2.ops:
        if "strides" in op.attrs:
            ps2_strides = op.attrs["strides"]
        if "dilation" in op.attrs:
            ps2_dilation = op.attrs["dilation"]

    ifm_idx, _, weight_idx, _, _ = op.get_ifm_ifm2_weight_bias_ofm_indices()

    rolling_buffer_sizes = []

    weight_tensor = op.inputs[weight_idx]

    ofm_block = Block(block_config_ps2[-3], block_config_ps2[-4], block_config_ps2[-1])
    kernel = Kernel(
        weight_tensor.shape[1], weight_tensor.shape[0], ps2_strides[2], ps2_strides[1], ps2_dilation[2], ps2_dilation[1]
    )
    kernel_block = Block(weight_tensor.shape[1], weight_tensor.shape[0], 65536)

    if ps2.npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
        ifm_block_depth = arch.calc_ifm_block_depth(
            op.inputs[ifm_idx].shape[-1], op.inputs[ifm_idx].dtype.size_in_bits()
        )
    else:
        ifm_block_depth = block_config_ps2[-1]

    ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, kernel_block)

    # The height calculation below is a worst-case estimate
    height = numeric_util.round_up(ifm_block.height + block_config_ps1[0], block_config_ps1[0])
    width = ifm_block.width

    rolling_buffer_sizes.append(height)
    rolling_buffer_sizes.append(width)

    return rolling_buffer_sizes


class PassCycles(enum.IntEnum):
    Dpu = 0
    ElementWise = 1
    Cpu = 2
    SramAccess = 3
    TotalPerPass = 4
    DramAccess = 5
    OnChipFlashAccess = 6
    OffChipFlashAccess = 7
    Total = 8
    Size = 9

    def display_name(self):
        return (
            "DPU",
            "Element wise",
            "CPU",
            "SRAM Access",
            "Total per Pass",
            "DRAM Access",
            "On-chip Flash Access",
            "Off-chip Flash Access",
            "Total",
            "Size",
        )[self.value]

    def identifier_name(self):
        return (
            "dpu",
            "element_wise",
            "cpu",
            "sram_access",
            "total_per_pass",
            "dram_access",
            "on_chip_flash_access",
            "off_chip_flash_access",
            "total",
            "size",
        )[self.value]

    @staticmethod
    def all():
        return (
            PassCycles.Dpu,
            PassCycles.ElementWise,
            PassCycles.Cpu,
            PassCycles.SramAccess,
            PassCycles.DramAccess,
            PassCycles.OnChipFlashAccess,
            PassCycles.OffChipFlashAccess,
            PassCycles.Total,
        )


class MacCount(enum.IntEnum):
    NeuralNetworkMacs = 0
    HardwareMacs = 1
    Size = 2

    def display_name(self):
        return ("Neural Network Macs", "Hardware Macs", "Size")[self.value]

    def identifier_name(self):
        return ("nn_macs", "hardware_macs", "size")[self.value]

    @staticmethod
    def all():
        return (MacCount.NeuralNetworkMacs, MacCount.HardwareMacs)


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = 1
    Size = 2

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


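# Bandwidth arrays are indexed as bws[mem_area][tensor_purpose][direction], e.g.
# bws[MemArea.Sram][TensorPurpose.Weights][BandwidthDirection.Read] accumulates the bytes of weight
# data read from SRAM (illustrative indexing; any purpose defined in TensorPurpose works).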
def make_bandwidth_array():
    return np.zeros((MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size))


def make_macs_array():
    return np.zeros(MacCount.Size, int)


def make_cycles_array():
    return np.zeros(PassCycles.Size)


def make_metrics_arrays():
    return (make_bandwidth_array(), make_macs_array(), make_cycles_array())


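# Count how many IFM blocks get read and the total element area of those reads, splitting the IFM
# into full blocks plus right/bottom edge remainders (the four cases sketched in the comment below).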
def get_n_blocks_and_area(
    ifm_brick_size, ifm_height_width, orig_skirt, clamped_skirt, block_config, min_block_size, strides
):

    ifm_block_config = (block_config[0] * strides[1], block_config[1] * strides[2])

    n_normal_blocks = []
    remainder_size = []
    for i in range(2):
        non_skirt_dim = ifm_height_width[i] - orig_skirt[i] - orig_skirt[2 + i]
        n_blocks = non_skirt_dim // ifm_block_config[i]
        n_normal_blocks.append(n_blocks)
        remainder_dim = numeric_util.round_up(
            ((non_skirt_dim - n_blocks * ifm_block_config[i] - 1) // strides[i + 1]) + 1, min_block_size[i]
        )
        remainder_size.append(remainder_dim)

    # this will actually calculate reads into the edge padding.

    # there are four cases in total, handling the edges that will not fill a complete block.

    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 2222222223
    total_blocks = 0
    total_area = 0

    block_setup = (
        (n_normal_blocks[0] * n_normal_blocks[1], block_config),
        (1 * n_normal_blocks[1], (remainder_size[0], block_config[1])),
        (n_normal_blocks[0] * 1, (block_config[0], remainder_size[1])),
        (1 * 1, remainder_size),
    )

    for n_blocks, block_size in block_setup:
        if block_size[0] == 0 or block_size[1] == 0:
            continue
        read_dims = [0, 0]
        for i in range(2):
            read_dims[i] = (
                numeric_util.round_up(clamped_skirt[i], ifm_brick_size[i + 1])
                + block_size[i] * strides[i + 1]
                + numeric_util.round_up(clamped_skirt[2 + i], ifm_brick_size[i + 1])
            )
        assert n_blocks >= 0
        total_blocks += n_blocks
        total_area += n_blocks * read_dims[0] * read_dims[1]
    assert total_blocks >= 1
    return total_blocks, total_area, block_setup


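# Estimate a single Pass. Returns (bandwidths, macs, cycles, blocks, ifm_read_multiple,
# weight_read_multiple); the cycle entries follow the max-of-bandwidth-and-compute model described
# at the top of the file.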
def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
    if block_config is None:
        block_config = ps.block_config
    bws = make_bandwidth_array()
    macs = make_macs_array()
    cycles = make_cycles_array()
    blocks = 0
    ifm_read_multiple = 1
    weight_read_multiple = 0

    if ps.placement in set((PassPlacement.MemoryOnly, PassPlacement.StartupInit)):
        return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass

    min_block_size = arch.min_block_sizes[ps.npu_block_type]

    skirt = (0, 0, 0, 0)
    explicit_padding = (0, 0, 0, 0)
    primary_op = ps.primary_op
    replacement_read_bws = {}
    if primary_op:
        skirt = primary_op.attrs.get("skirt", skirt)
        explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
        assert primary_op.attrs["npu_block_type"] == ps.npu_block_type
        npu_block_type = primary_op.attrs["npu_block_type"]

        ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()

        npu_convolution_ops = set((NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise))
        if (npu_block_type == NpuBlockType.Pooling and len(ifm_tensor.shape) == 4) or (
            npu_block_type in npu_convolution_ops
        ):

            batch_size = ifm_tensor.shape[0]
            ifm_tensor_shape = list(ifm_tensor.shape)
            ifm_depth = ifm_tensor.bandwidth_shape[3]

            # add in padding
            ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
            ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3]  # width += left and right

            strides = primary_op.attrs["strides"]
            if npu_block_type != NpuBlockType.Pooling:
                weight_tensor_shape = weight_tensor.shape
                weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
                weight_tensor_element_size = weight_tensor.element_size()
                weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
                nn_ops = (
                    int(ofm_tensor.shape[0])
                    * int(ofm_tensor.shape[1])
                    * int(ofm_tensor.shape[2])
                    * int(weight_tensor_shape[0])
                    * int(weight_tensor_shape[1])
                    * int(weight_tensor_shape[2])
                    * int(weight_tensor_shape[3])
                    / int(strides[1])
                    / int(strides[2])
                )
            else:
                weight_tensor_shape = [
                    primary_op.attrs["ksize"][1],
                    primary_op.attrs["ksize"][2],
                    1,
                    ifm_tensor_shape[3],
                ]
                weight_tensor_bandwidth_shape = weight_tensor_shape
                weight_tensor_element_size = 0
                weight_tensor_bandwidth_compression_scale = 0.0
                nn_ops = 0  # pooling doesn't count as NN ops

            kernel_dims = weight_tensor_shape[:2]

            sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
            # count the sub kernels; the IFM block needs to be refetched for each of them
            n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
            n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
            n_sub_kernels = n_sub_kernels_y * n_sub_kernels_x

            clamped_skirt = list(skirt)
            clamped_skirt[2] = min(clamped_skirt[2], sub_kernel_limits[0] - 1 - clamped_skirt[0])
            clamped_skirt[3] = min(clamped_skirt[3], sub_kernel_limits[1] - 1 - clamped_skirt[1])
            n_blocks, area, block_setup = get_n_blocks_and_area(
                ifm_tensor.brick_size,
                ifm_tensor_shape[1:3],
                skirt,
                clamped_skirt,
                block_config,
                min_block_size,
                strides,
            )

            blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], block_config[3])

            n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], block_config[3])
            if npu_block_type == NpuBlockType.ConvolutionDepthWise or npu_block_type == NpuBlockType.Pooling:
                n_weight_stages = 1  # force to no reread

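            # IFM traffic below: each sub-kernel and each weight stage re-reads the block area, so
            # bytes are roughly refetches * batches * area * depth * re-reads * element size
            # * compression scale.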
            ifm_tensor_bw = (
                n_sub_kernels
                * batch_size
                * area
                * ifm_depth
                * n_weight_stages
                * ifm_tensor.element_size()
                * ifm_tensor.bandwidth_compression_scale
            )
            replacement_read_bws[ifm_tensor] = ifm_tensor_bw
            ifm_read_multiple = n_weight_stages

            replacement_read_bws[weight_tensor] = (
                batch_size
                * shape_num_elements(weight_tensor_bandwidth_shape)
                * weight_tensor_element_size
                * weight_tensor_bandwidth_compression_scale
                * n_blocks
            )  # read once per block and batch
            weight_read_multiple = n_blocks

            n_kernel_xy = kernel_dims[0] * kernel_dims[1]
            n_input_channels_at_a_time = block_config[2]

            if npu_block_type == NpuBlockType.Pooling or weight_tensor.block_traversal in set(
                (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
            ):
                n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
                n_kernel_xy = max(
                    n_kernel_xy, 4
                )  # need at least 4, as this is the minimum duty cycle for secondary accumulator writes
            if weight_tensor is not None:
                n_kernel_xy = numeric_util.round_up(
                    n_kernel_xy, 4
                )  # weights need to be read in blocks of 4

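            # MAC count below: for each block case, batches * blocks * block height * block width
            # * IFM and OFM depths (rounded up to traversal/block granularity) * kernel positions
            # (n_kernel_xy).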
            num_mac_ops = 0
            for n_blocks_for_size, block_size in block_setup:
                num_mac_ops += (
                    batch_size
                    * n_blocks_for_size
                    * block_size[0]
                    * block_size[1]
                    * numeric_util.round_up(weight_tensor_shape[2], n_input_channels_at_a_time)
                    * numeric_util.round_up(weight_tensor_shape[3], block_config[3])
                    * n_kernel_xy
                )

            if npu_block_type == NpuBlockType.Pooling:
                # TODO: improve pooling estimation
                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle / 2
            else:
                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
            macs[MacCount.NeuralNetworkMacs] += nn_ops
            macs[MacCount.HardwareMacs] += num_mac_ops

        elif npu_block_type == NpuBlockType.VectorProduct:
            nn_macs = (
                ifm_tensor.shape[0]
                * numeric_util.round_up(weight_tensor.shape[-2], block_config[2])
                * numeric_util.round_up(weight_tensor.shape[-1], block_config[3])
            )
            num_mac_ops = nn_macs

            cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
            macs[MacCount.NeuralNetworkMacs] += nn_macs
            macs[MacCount.HardwareMacs] += num_mac_ops

            blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], block_config[3])

            non_zero_fraction = 1.0
            if ifm_tensor.values is not None:
                nz_vector = np.amax(ifm_tensor.values != 0, axis=0)  # max across batch axis
                non_zero_fraction = np.average(nz_vector)

            replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth()
            replacement_read_bws[weight_tensor] = weight_tensor.bandwidth() * non_zero_fraction
            ifm_read_multiple = 1
            weight_read_multiple = non_zero_fraction
    else:
        if ps.placement == PassPlacement.Npu and len(ps.outputs):
            # Assume element-wise operation going through the element pipelines.
            # Work out how many elements we have and calculate performance.
            out = ps.outputs[0]
            elms = out.elements()

            cycles[PassCycles.ElementWise] = numeric_util.round_up_divide(elms, arch.num_elem_wise_units)

        if ps.placement == PassPlacement.Cpu:
            cycles[PassCycles.Cpu] = arch.cpu_cycle_estimate(ps.ops[0])

    # apply the desired rewrites
    for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list:
        if ps != ps_to_rewrite:
            continue
        if rewrite_op == SchedulerRewrite.Nop:
            pass  # these are fine, no bandwidth changes
        elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,):
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += replacement_read_bws[tens]
            replacement_read_bws[tens] = 0

    for tens in ps.outputs:
        if force_outputs_to_fast_storage:
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
        else:
            bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

    for tens in ps.intermediates:
        bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

    for tens in ps.inputs:
        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

    cycles[PassCycles.SramAccess] = np.sum(bws[MemArea.Sram]) / arch.memory_bandwidths_per_cycle[MemArea.Sram]
    cycles[PassCycles.TotalPerPass] = np.max(cycles[: PassCycles.TotalPerPass])

    # Quickly build access counts for the current pass only, even though these aren't the final numbers
    update_summary_cycles(arch, bws, macs, cycles)

    return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple


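# Convert the accumulated DRAM and flash traffic into access cycles and take the running maximum as
# the total, mirroring the per-pass SRAM/compute maximum taken above.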
def update_summary_cycles(arch, bws, macs, cycles):
    cycles[PassCycles.DramAccess] = np.sum(bws[MemArea.Dram]) / arch.memory_bandwidths_per_cycle[MemArea.Dram]
    cycles[PassCycles.OnChipFlashAccess] = (
        np.sum(bws[MemArea.OnChipFlash]) / arch.memory_bandwidths_per_cycle[MemArea.OnChipFlash]
    )
    cycles[PassCycles.OffChipFlashAccess] = (
        np.sum(bws[MemArea.OffChipFlash]) / arch.memory_bandwidths_per_cycle[MemArea.OffChipFlash]
    )

    cycles[PassCycles.Total] = np.max(cycles[: PassCycles.Total])
    return cycles


def collate_stats_for_cascaded_pass(arch, bws, macs, cycles):
    return bws, macs, cycles


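# Evaluate every pass in a cascaded pass, store the per-pass metrics on each Pass and attach/return
# the summed totals for the CascadedPass.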
def performance_for_cascaded_pass(arch, cps):
    total_bws = make_bandwidth_array()
    total_macs = make_macs_array()
    total_cycles = make_cycles_array()

    for ps in cps.passes:
        bws, macs, cycles, blocks, _, _ = performance_metrics_for_pass(arch, ps)
        ps.bandwidths = bws
        ps.macs = macs
        ps.cycles = cycles
        ps.n_blocks = blocks
        total_bws += bws
        total_macs += macs
        total_cycles += cycles

    bws, macs, cycles = collate_stats_for_cascaded_pass(arch, total_bws, total_macs, total_cycles)
    cps.bandwidths = bws
    cps.macs = macs
    cps.cycles = cycles
    return bws, macs, cycles


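# Sum the cascaded-pass metrics over all subgraphs, adding arch.inter_pass_cycle_delay per cascaded
# pass, and attach the totals to the network (nng).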
def calc_performance_for_network(nng, arch):
    total_bws = make_bandwidth_array()
    total_macs = np.zeros(MacCount.Size)
    total_cycles = np.zeros(PassCycles.Size)

    for sg in nng.subgraphs:
        for cps in sg.cascaded_passes:
            bws, macs, cycles = performance_for_cascaded_pass(arch, cps)
            total_bws += bws
            total_macs += macs
            total_cycles += cycles
            total_cycles += arch.inter_pass_cycle_delay

    nng.bandwidths = total_bws
    nng.macs = total_macs
    nng.cycles = total_cycles