# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# NPU performance estimation functions to estimate performance of a Pass and CascadedPass. Uses a model that takes the
# maximum of the 'cycles required for bandwidth' and 'cycles required for computing'.
#
# Called during scheduling to evaluate different proposals, as well as post-scheduling to provide a final performance
# estimate.
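#
# Example (illustrative numbers, not tied to any particular configuration): a pass needing 1.0e6
# MACs on a 256 MACs/cycle engine costs ~3906 compute cycles; if the same pass also moves 2.0e5
# bytes of SRAM traffic at 16 bytes/cycle (12500 cycles), the bandwidth term dominates and the
# pass is estimated at 12500 cycles.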

import enum

import numpy as np

from . import numeric_util
from .tensor import TensorPurpose, MemArea, shape_num_elements, TensorBlockTraversal
from .nn_graph import PassPlacement, SchedulerRewrite
from .operation import NpuBlockType
from .architecture_features import Block, Kernel


def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_config_ps2):
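    """Work out the [height, width] rolling buffer dimensions needed between two cascaded passes.

    The height is rounded up to a whole number of ps1 block rows (a worst-case estimate) and the
    width is that of the IFM block that ps2 consumes.
    """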
    ps2_strides = (1, 1, 1, 1)
    ps2_dilation = (1, 1, 1, 1)
    for op in ps2.ops:
        if "strides" in op.attrs:
            ps2_strides = op.attrs["strides"]
        if "dilation" in op.attrs:
            ps2_dilation = op.attrs["dilation"]

    ifm_idx, _, weight_idx, _, _ = op.get_ifm_ifm2_weight_bias_ofm_indices()

    rolling_buffer_sizes = []

    weight_tensor = op.inputs[weight_idx]

    ofm_block = Block(block_config_ps2[-3], block_config_ps2[-4], block_config_ps2[-1])
    kernel = Kernel(
        weight_tensor.shape[1], weight_tensor.shape[0], ps2_strides[2], ps2_strides[1], ps2_dilation[2], ps2_dilation[1]
    )
    kernel_block = Block(weight_tensor.shape[1], weight_tensor.shape[0], 65536)

    if ps2.npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
        ifm_block_depth = arch.calc_ifm_block_depth(
            op.inputs[ifm_idx].shape[-1], op.inputs[ifm_idx].dtype.size_in_bits()
        )
    else:
        ifm_block_depth = block_config_ps2[-1]

    ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, kernel_block)

    # The height calculation below is for the worst case
    height = numeric_util.round_up(ifm_block.height + block_config_ps1[0], block_config_ps1[0])
    width = ifm_block.width

    rolling_buffer_sizes.append(height)
    rolling_buffer_sizes.append(width)

    return rolling_buffer_sizes


class PassCycles(enum.IntEnum):
    Dpu = 0
    ElementWise = 1
    Cpu = 2
    SramAccess = 3
    TotalPerPass = 4
    DramAccess = 5
    OnChipFlashAccess = 6
    OffChipFlashAccess = 7
    Total = 8
    Size = 9

    def display_name(self):
        return (
            "DPU",
            "Element wise",
            "CPU",
            "SRAM Access",
            "Total per Pass",
            "DRAM Access",
            "On-chip Flash Access",
            "Off-chip Flash Access",
            "Total",
            "Size",
        )[self.value]

    def identifier_name(self):
        return (
            "dpu",
            "element_wise",
            "cpu",
            "sram_access",
            "total_per_pass",
            "dram_access",
            "on_chip_flash_access",
            "off_chip_flash_access",
            "total",
            "size",
        )[self.value]

    @staticmethod
    def all():
        return (
            PassCycles.Dpu,
            PassCycles.ElementWise,
            PassCycles.Cpu,
            PassCycles.SramAccess,
            PassCycles.DramAccess,
            PassCycles.OnChipFlashAccess,
            PassCycles.OffChipFlashAccess,
            PassCycles.Total,
        )


class MacCount(enum.IntEnum):
    NeuralNetworkMacs = 0
    HardwareMacs = 1
    Size = 2

    def display_name(self):
        return ("Neural Network Macs", "Hardware Macs", "Size")[self.value]

    def identifier_name(self):
        return ("nn_macs", "hardware_macs", "size")[self.value]

    @staticmethod
    def all():
        return (MacCount.NeuralNetworkMacs, MacCount.HardwareMacs)


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = 1
    Size = 2

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


def make_bandwidth_array():
    return np.zeros((MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size))


def make_macs_array():
    return np.zeros(MacCount.Size, int)


def make_cycles_array():
    return np.zeros(PassCycles.Size)


def make_metrics_arrays():
    return (make_bandwidth_array(), make_macs_array(), make_cycles_array())


def get_n_blocks_and_area(
    ifm_brick_size, ifm_height_width, orig_skirt, clamped_skirt, block_config, min_block_size, strides
):
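    """Work out how many IFM blocks, including partial blocks at the edges, are needed to cover the
    IFM, and the total IFM area read in the process.

    Returns (total_blocks, total_area, block_setup), where block_setup lists (count, block_size)
    pairs for the four block shapes: interior, right edge, bottom edge and corner.
    """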

    ifm_block_config = (block_config[0] * strides[1], block_config[1] * strides[2])

    n_normal_blocks = []
    remainder_size = []
    for i in range(2):
        non_skirt_dim = ifm_height_width[i] - orig_skirt[i] - orig_skirt[2 + i]
        n_blocks = non_skirt_dim // ifm_block_config[i]
        n_normal_blocks.append(n_blocks)
        remainder_dim = numeric_util.round_up(
            ((non_skirt_dim - n_blocks * ifm_block_config[i] - 1) // strides[i + 1]) + 1, min_block_size[i]
        )
        remainder_size.append(remainder_dim)

    # this will actually calculate reads into the edge padding.

    # there are four cases in total, handling the edges that will not fill a complete block.

    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 2222222223
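    # In the diagram above, 0 = full-size interior blocks, 1 = right-edge blocks with remainder
    # width, 2 = bottom-edge blocks with remainder height, 3 = the corner block with remainder in
    # both dimensions; block_setup below enumerates these four cases.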
    total_blocks = 0
    total_area = 0

    block_setup = (
        (n_normal_blocks[0] * n_normal_blocks[1], block_config),
        (1 * n_normal_blocks[1], (remainder_size[0], block_config[1])),
        (n_normal_blocks[0] * 1, (block_config[0], remainder_size[1])),
        (1 * 1, remainder_size),
    )

    for n_blocks, block_size in block_setup:
        if block_size[0] == 0 or block_size[1] == 0:
            continue
        read_dims = [0, 0]
        for i in range(2):
            read_dims[i] = (
                numeric_util.round_up(clamped_skirt[i], ifm_brick_size[i + 1])
                + block_size[i] * strides[i + 1]
                + numeric_util.round_up(clamped_skirt[2 + i], ifm_brick_size[i + 1])
            )
        assert n_blocks >= 0
        total_blocks += n_blocks
        total_area += n_blocks * read_dims[0] * read_dims[1]
    assert total_blocks >= 1
    return total_blocks, total_area, block_setup


def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, force_outputs_to_fast_storage=False):
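    """Estimate the performance of a single Pass.

    Returns (bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple): a bandwidth array
    indexed by [MemArea][TensorPurpose][BandwidthDirection], MAC counts, per-category cycle counts,
    the number of blocks the OFM is produced in, and how many times the IFM and weights are re-read.
    """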
    if block_config is None:
        block_config = ps.block_config
    if rewrite_list is None:
        rewrite_list = []
    bws = make_bandwidth_array()
    macs = make_macs_array()
    cycles = make_cycles_array()
    blocks = 0
    ifm_read_multiple = 1
    weight_read_multiple = 0

    if ps.placement in set((PassPlacement.MemoryOnly, PassPlacement.StartupInit)):
        return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass

    min_block_size = arch.min_block_sizes[ps.npu_block_type]

    skirt = (0, 0, 0, 0)
    explicit_padding = (0, 0, 0, 0)
    primary_op = ps.primary_op
    replacement_read_bws = {}
    if primary_op:
        skirt = primary_op.attrs.get("skirt", skirt)
        explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
        assert primary_op.attrs["npu_block_type"] == ps.npu_block_type
        npu_block_type = primary_op.attrs["npu_block_type"]

        ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()

        npu_convolution_ops = set((NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise))
        if (npu_block_type == NpuBlockType.Pooling and len(ifm_tensor.shape) == 4) or (
            npu_block_type in npu_convolution_ops
        ):

            batch_size = ifm_tensor.shape[0]
            ifm_tensor_shape = list(ifm_tensor.shape)
            ifm_depth = ifm_tensor.bandwidth_shape[3]

            # add in padding
            ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
            ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3]  # width += left and right

            strides = primary_op.attrs["strides"]
            if npu_block_type != NpuBlockType.Pooling:
                weight_tensor_shape = weight_tensor.shape
                weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
                weight_tensor_element_size = weight_tensor.element_size()
                weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
                nn_ops = (
                    int(ofm_tensor.shape[0])
                    * int(ofm_tensor.shape[1])
                    * int(ofm_tensor.shape[2])
                    * int(weight_tensor_shape[0])
                    * int(weight_tensor_shape[1])
                    * int(weight_tensor_shape[2])
                    * int(weight_tensor_shape[3])
                    / int(strides[1])
                    / int(strides[2])
                )
            else:
                weight_tensor_shape = [
                    primary_op.attrs["ksize"][1],
                    primary_op.attrs["ksize"][2],
                    1,
                    ifm_tensor_shape[3],
                ]
                weight_tensor_bandwidth_shape = weight_tensor_shape
                weight_tensor_element_size = 0
                weight_tensor_bandwidth_compression_scale = 0.0
                nn_ops = 0  # pooling doesn't count as NN ops

            kernel_dims = weight_tensor_shape[:2]

            sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
            # count the sub kernels; the IFM block needs to be refetched for each of them
            n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
            n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
            n_sub_kernels = n_sub_kernels_y * n_sub_kernels_x
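            # Example (illustrative; the actual limits come from arch.sub_kernel_limits): with an
            # 8x8 sub-kernel limit, a 9x9 kernel splits into 2x2 = 4 sub-kernels, so each IFM
            # block is fetched 4 times.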

            clamped_skirt = list(skirt)
            clamped_skirt[2] = min(clamped_skirt[2], sub_kernel_limits[0] - 1 - clamped_skirt[0])
            clamped_skirt[3] = min(clamped_skirt[3], sub_kernel_limits[1] - 1 - clamped_skirt[1])
            n_blocks, area, block_setup = get_n_blocks_and_area(
                ifm_tensor.brick_size,
                ifm_tensor_shape[1:3],
                skirt,
                clamped_skirt,
                block_config,
                min_block_size,
                strides,
            )

            blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], block_config[3])

            n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], block_config[3])
            if npu_block_type == NpuBlockType.ConvolutionDepthWise or npu_block_type == NpuBlockType.Pooling:
                n_weight_stages = 1  # force to no reread
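            # Example (illustrative): an OFM depth of 64 with an OFM block depth (block_config[3])
            # of 16 gives 4 weight stages, so the IFM below is counted as being streamed 4 times.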

            ifm_tensor_bw = (
                n_sub_kernels
                * batch_size
                * area
                * ifm_depth
                * n_weight_stages
                * ifm_tensor.element_size()
                * ifm_tensor.bandwidth_compression_scale
            )
            replacement_read_bws[ifm_tensor] = ifm_tensor_bw
            ifm_read_multiple = n_weight_stages

            replacement_read_bws[weight_tensor] = (
                batch_size
                * shape_num_elements(weight_tensor_bandwidth_shape)
                * weight_tensor_element_size
                * weight_tensor_bandwidth_compression_scale
                * n_blocks
            )  # read once per block and batch
            weight_read_multiple = n_blocks

            n_kernel_xy = kernel_dims[0] * kernel_dims[1]
            n_input_channels_at_a_time = block_config[2]

            if npu_block_type == NpuBlockType.Pooling or weight_tensor.block_traversal in set(
                (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
            ):
                n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
                n_kernel_xy = max(
                    n_kernel_xy, 4
                )  # need at least 4, as this is the minimum duty cycle for secondary accumulator writes
            if weight_tensor is not None:
                n_kernel_xy = numeric_util.round_up(n_kernel_xy, 4)  # weights need to be read in blocks of 4

            num_mac_ops = 0
            for n_blocks_for_size, block_size in block_setup:
                num_mac_ops += (
                    batch_size
                    * n_blocks_for_size
                    * block_size[0]
                    * block_size[1]
                    * numeric_util.round_up(weight_tensor_shape[2], n_input_channels_at_a_time)
                    * numeric_util.round_up(weight_tensor_shape[3], block_config[3])
                    * n_kernel_xy
                )

            if npu_block_type == NpuBlockType.Pooling:
                # TODO: improve pooling estimation
                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle / 2
            else:
                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
            macs[MacCount.NeuralNetworkMacs] += nn_ops
            macs[MacCount.HardwareMacs] += num_mac_ops

        elif npu_block_type == NpuBlockType.VectorProduct:
            nn_macs = (
                ifm_tensor.shape[0]
                * numeric_util.round_up(weight_tensor.shape[-2], block_config[2])
                * numeric_util.round_up(weight_tensor.shape[-1], block_config[3])
            )
            num_mac_ops = nn_macs

            cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
            macs[MacCount.NeuralNetworkMacs] += nn_macs
            macs[MacCount.HardwareMacs] += num_mac_ops

            blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], block_config[3])

            non_zero_fraction = 1.0
            if ifm_tensor.values is not None:
                nz_vector = np.amax(ifm_tensor.values != 0, axis=0)  # max across batch axis
                non_zero_fraction = np.average(nz_vector)

            replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth()
            replacement_read_bws[weight_tensor] = weight_tensor.bandwidth() * non_zero_fraction
            ifm_read_multiple = 1
            weight_read_multiple = non_zero_fraction
    else:
        if ps.placement == PassPlacement.Npu and len(ps.outputs):
            # Assume element-wise operation going through the element pipelines.
            # Work out how many elements we have and calculate performance.
            out = ps.outputs[0]
            elms = out.elements()

            cycles[PassCycles.ElementWise] = numeric_util.round_up_divide(elms, arch.num_elem_wise_units)

    if ps.placement == PassPlacement.Cpu:
        cycles[PassCycles.Cpu] = arch.cpu_cycle_estimate(ps.ops[0])

    # apply the desired rewrites
    for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list:
        if ps != ps_to_rewrite:
            continue
        if rewrite_op == SchedulerRewrite.Nop:
            pass  # these are fine, no bandwidth changes
        elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,):
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += replacement_read_bws[tens]
            replacement_read_bws[tens] = 0

    for tens in ps.outputs:
        if force_outputs_to_fast_storage:
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
        else:
            bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

    for tens in ps.intermediates:
        bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

    for tens in ps.inputs:
        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

    cycles[PassCycles.SramAccess] = np.sum(bws[MemArea.Sram]) / arch.memory_bandwidths_per_cycle[MemArea.Sram]
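    # Roofline: the pass is limited by whichever is slower of compute and SRAM access, so the
    # per-pass total is the maximum of the cycle categories accumulated so far.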
    cycles[PassCycles.TotalPerPass] = np.max(cycles[: PassCycles.TotalPerPass])

    # quickly build the access counts for only the current pass, even though these aren't the final numbers
    update_summary_cycles(arch, bws, macs, cycles)

    return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple


def update_summary_cycles(arch, bws, macs, cycles):
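    """Derive DRAM and flash access cycles from the bandwidth totals, then set the overall Total to
    the maximum over all cycle categories (the roofline model described at the top of this file)."""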
    cycles[PassCycles.DramAccess] = np.sum(bws[MemArea.Dram]) / arch.memory_bandwidths_per_cycle[MemArea.Dram]
    cycles[PassCycles.OnChipFlashAccess] = (
        np.sum(bws[MemArea.OnChipFlash]) / arch.memory_bandwidths_per_cycle[MemArea.OnChipFlash]
    )
    cycles[PassCycles.OffChipFlashAccess] = (
        np.sum(bws[MemArea.OffChipFlash]) / arch.memory_bandwidths_per_cycle[MemArea.OffChipFlash]
    )

    cycles[PassCycles.Total] = np.max(cycles[: PassCycles.Total])
    return cycles


def collate_stats_for_cascaded_pass(arch, bws, macs, cycles):
    return bws, macs, cycles


def performance_for_cascaded_pass(arch, cps):
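    """Estimate a CascadedPass by summing the metrics of its contained passes; the results are
    stored on both the individual passes and the cascaded pass, and also returned."""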
    total_bws = make_bandwidth_array()
    total_macs = make_macs_array()
    total_cycles = make_cycles_array()

    for ps in cps.passes:
        bws, macs, cycles, blocks, _, _ = performance_metrics_for_pass(arch, ps)
        ps.bandwidths = bws
        ps.macs = macs
        ps.cycles = cycles
        ps.n_blocks = blocks
        total_bws += bws
        total_macs += macs
        total_cycles += cycles

    bws, macs, cycles = collate_stats_for_cascaded_pass(arch, total_bws, total_macs, total_cycles)
    cps.bandwidths = bws
    cps.macs = macs
    cps.cycles = cycles
    return bws, macs, cycles


def calc_performance_for_network(nng, arch):
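    """Estimate the whole network by accumulating the metrics of every cascaded pass in every
    subgraph, plus a fixed inter-pass cycle delay per cascaded pass; results are stored on nng."""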
    total_bws = make_bandwidth_array()
    total_macs = np.zeros(MacCount.Size)
    total_cycles = np.zeros(PassCycles.Size)

    for sg in nng.subgraphs:
        for cps in sg.cascaded_passes:
            bws, macs, cycles = performance_for_cascaded_pass(arch, cps)
            total_bws += bws
            total_macs += macs
            total_cycles += cycles
            total_cycles += arch.inter_pass_cycle_delay

    nng.bandwidths = total_bws
    nng.macs = total_macs
    nng.cycles = total_cycles