# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# NPU performance estimation functions to estimate performance of a Pass and CascadedPass. Uses a model that takes the
# maximum of the 'cycles required for bandwidth' and 'cycles required for computing'.
#
# Called during scheduling to evaluate different proposals, as well as post-scheduling to provide a final performance
# estimate.
Tim Hall79d07d22020-04-27 18:20:16 +010022import enum
Diego Russoea6111a2020-04-14 18:41:58 +010023
Tim Hall79d07d22020-04-27 18:20:16 +010024import numpy as np
Diego Russoea6111a2020-04-14 18:41:58 +010025
26from . import numeric_util
Diego Russoe8a10452020-04-21 17:39:10 +010027from .architecture_features import Block
28from .architecture_features import Kernel
29from .nn_graph import PassPlacement
30from .nn_graph import SchedulerRewrite
Diego Russoea6111a2020-04-14 18:41:58 +010031from .operation import NpuBlockType
Diego Russoe8a10452020-04-21 17:39:10 +010032from .tensor import MemArea
33from .tensor import shape_num_elements
34from .tensor import TensorBlockTraversal
35from .tensor import TensorPurpose
Tim Hall79d07d22020-04-27 18:20:16 +010036
37
def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_config_ps2):
    """Return ``[height, width]`` of the rolling buffer sitting between two passes.

    The dimensions come from the IFM block that ``arch`` reports for one OFM
    block of ``ps2``. The height is a worst-case figure: one ``ps1`` block
    height is added and the sum is rounded up to a multiple of that block
    height. ``ps1`` itself is not inspected, only ``block_config_ps1[0]``.
    """
    strides = (1, 1, 1, 1)
    dilation = (1, 1, 1, 1)
    for op in ps2.ops:
        # later operations override the values found on earlier ones
        strides = op.attrs.get("strides", strides)
        dilation = op.attrs.get("dilation", dilation)

    # From here on ``op`` is the last operation of ps2 (the loop variable is reused on purpose).
    ifm_idx, _, weight_idx, _, _ = op.get_ifm_ifm2_weight_bias_ofm_indices()
    weights = op.inputs[weight_idx]
    w_dim0 = weights.shape[0]
    w_dim1 = weights.shape[1]

    ofm_block = Block(block_config_ps2[-3], block_config_ps2[-4], block_config_ps2[-1])
    kernel = Kernel(w_dim1, w_dim0, strides[2], strides[1], dilation[2], dilation[1])
    kernel_block = Block(w_dim1, w_dim0, 65536)

    if ps2.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
        ifm = op.inputs[ifm_idx]
        ifm_block_depth = arch.calc_ifm_block_depth(ifm.shape[-1], ifm.dtype.size_in_bits())
    else:
        ifm_block_depth = block_config_ps2[-1]

    ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, kernel_block)

    # The performed height calculation is for worst case
    ps1_block_height = block_config_ps1[0]
    worst_case_height = numeric_util.round_up(ifm_block.height + ps1_block_height, ps1_block_height)

    return [worst_case_height, ifm_block.width]
76
77
class PassCycles(enum.IntEnum):
    """Index of each entry in a per-pass cycles array.

    ``Size`` is the number of entries and is used to dimension the array; it
    is not a real cycle category.
    """

    Dpu = 0
    ElementWise = 1
    Cpu = 2
    SramAccess = 3
    TotalPerPass = 4
    DramAccess = 5
    OnChipFlashAccess = 6
    OffChipFlashAccess = 7
    Total = 8
    Size = 9

    def display_name(self):
        """Return the human-readable label of this entry."""
        labels = {
            PassCycles.Dpu: "DPU",
            PassCycles.ElementWise: "Element wise",
            PassCycles.Cpu: "CPU",
            PassCycles.SramAccess: "SRAM Access",
            PassCycles.TotalPerPass: "Total per Pass",
            PassCycles.DramAccess: "DRAM Access",
            PassCycles.OnChipFlashAccess: "On-chip Flash Access",
            PassCycles.OffChipFlashAccess: "Off-chip Flash Access",
            PassCycles.Total: "Total",
            PassCycles.Size: "Size",
        }
        return labels[self]

    def identifier_name(self):
        """Return the lower-case, underscore-separated identifier of this entry."""
        identifiers = {
            PassCycles.Dpu: "dpu",
            PassCycles.ElementWise: "element_wise",
            PassCycles.Cpu: "cpu",
            PassCycles.SramAccess: "sram_access",
            PassCycles.TotalPerPass: "total_per_pass",
            PassCycles.DramAccess: "dram_access",
            PassCycles.OnChipFlashAccess: "on_chip_flash_access",
            PassCycles.OffChipFlashAccess: "off_chip_flash_access",
            PassCycles.Total: "total",
            PassCycles.Size: "size",
        }
        return identifiers[self]

    @staticmethod
    def all():
        """Return the reportable entries in declaration order (no TotalPerPass, no Size)."""
        excluded = (PassCycles.TotalPerPass, PassCycles.Size)
        return tuple(entry for entry in PassCycles if entry not in excluded)
130
131
class MacCount(enum.IntEnum):
    """Index of each entry in a MAC-count array; ``Size`` is the array length."""

    NeuralNetworkMacs = 0
    HardwareMacs = 1
    Size = 2

    def display_name(self):
        """Return the human-readable label of this entry."""
        labels = {
            MacCount.NeuralNetworkMacs: "Neural Network Macs",
            MacCount.HardwareMacs: "Hardware Macs",
            MacCount.Size: "Size",
        }
        return labels[self]

    def identifier_name(self):
        """Return the lower-case, underscore-separated identifier of this entry."""
        identifiers = {
            MacCount.NeuralNetworkMacs: "nn_macs",
            MacCount.HardwareMacs: "hardware_macs",
            MacCount.Size: "size",
        }
        return identifiers[self]

    @staticmethod
    def all():
        """Return every real entry, i.e. everything except ``Size``."""
        return tuple(entry for entry in MacCount if entry is not MacCount.Size)
146
147
class BandwidthDirection(enum.IntEnum):
    """Index of the transfer direction in a bandwidth array; ``Size`` is the axis length."""

    Read = 0
    Write = 1
    Size = 2

    def display_name(self):
        """Return the member name unchanged."""
        return self.name

    def identifier_name(self):
        """Return the member name in lower case."""
        return self.display_name().lower()

    @staticmethod
    def all():
        """Return every real direction, i.e. everything except ``Size``."""
        return tuple(entry for entry in BandwidthDirection if entry is not BandwidthDirection.Size)
162
163
def make_bandwidth_array():
    """Return a zero-filled bandwidth array indexed as [memory area][tensor purpose][direction]."""
    shape = (MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size)
    return np.zeros(shape)
166
167
def make_macs_array():
    """Return a zero-filled integer array with one entry per ``MacCount`` member below ``Size``."""
    # ``np.int`` was only an alias of the builtin ``int``; it was deprecated in NumPy 1.20 and
    # removed in 1.24, so the builtin is used directly to get the same dtype on every version.
    return np.zeros(MacCount.Size, dtype=int)
170
171
def make_cycles_array():
    """Return a zero-filled array with one entry per ``PassCycles`` member below ``Size``."""
    n_entries = PassCycles.Size
    return np.zeros(n_entries)
174
175
def make_metrics_arrays():
    """Return fresh, zeroed ``(bandwidths, macs, cycles)`` arrays."""
    bandwidths = make_bandwidth_array()
    macs = make_macs_array()
    cycles = make_cycles_array()
    return bandwidths, macs, cycles
178
179
def get_n_blocks_and_area(
    ifm_brick_size, ifm_height_width, orig_skirt, clamped_skirt, block_config, min_block_size, strides
):
    """Count the blocks needed to cover the IFM and the total area read for them.

    The IFM (minus the original skirt) is tiled with full blocks of
    ``block_config`` size; whatever is left over along each dimension becomes a
    remainder block, rounded up to ``min_block_size``.

    Returns ``(total_blocks, total_area, block_setup)`` where ``block_setup`` is
    a 4-tuple of ``(block count, block size)`` pairs for the full blocks, the
    dimension-0 remainder, the dimension-1 remainder and the corner. Entries
    whose size has a zero dimension are kept in ``block_setup`` but do not
    contribute to the totals.
    """
    ifm_block_config = (block_config[0] * strides[1], block_config[1] * strides[2])

    n_normal_blocks = []
    remainder_size = []
    for axis, ifm_block_dim in enumerate(ifm_block_config):
        non_skirt_dim = ifm_height_width[axis] - orig_skirt[axis] - orig_skirt[2 + axis]
        full_blocks = non_skirt_dim // ifm_block_dim
        n_normal_blocks.append(full_blocks)
        leftover = non_skirt_dim - full_blocks * ifm_block_dim
        # a leftover of zero yields (-1 // stride) + 1 == 0, i.e. no remainder block
        remainder_size.append(numeric_util.round_up(((leftover - 1) // strides[axis + 1]) + 1, min_block_size[axis]))

    # this will actually calculate reads into the edge padding.

    # there are four cases in total, handling the edges that will not fill a complete block.

    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 2222222223
    block_setup = (
        (n_normal_blocks[0] * n_normal_blocks[1], block_config),
        (1 * n_normal_blocks[1], (remainder_size[0], block_config[1])),
        (n_normal_blocks[0] * 1, (block_config[0], remainder_size[1])),
        (1 * 1, remainder_size),
    )

    total_blocks = 0
    total_area = 0
    for count, size in block_setup:
        if 0 in (size[0], size[1]):
            continue  # this case does not occur for the given dimensions
        # every block is read together with its (brick-aligned) leading and trailing skirt
        read_dim0, read_dim1 = (
            numeric_util.round_up(clamped_skirt[axis], ifm_brick_size[axis + 1])
            + size[axis] * strides[axis + 1]
            + numeric_util.round_up(clamped_skirt[2 + axis], ifm_brick_size[axis + 1])
            for axis in range(2)
        )
        assert count >= 0
        total_blocks += count
        total_area += count * read_dim0 * read_dim1
    assert total_blocks >= 1
    return total_blocks, total_area, block_setup
233
234
def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, force_outputs_to_fast_storage=False):
    """Estimate bandwidth, MAC count and cycles for a single pass.

    Args:
        arch: architecture description supplying block limits, MACs per cycle,
            memory bandwidths per cycle and the fast storage memory area.
        ps: the pass to evaluate.
        block_config: block configuration to evaluate; defaults to ``ps.block_config``.
        rewrite_list: optional iterable of 6-tuples
            ``(rewrite_op, tensor, _, _, _, pass)``; only entries whose pass
            equals ``ps`` are applied. ``None`` means no rewrites.
        force_outputs_to_fast_storage: when true, output writes are accounted
            to ``arch.fast_storage_mem_area`` instead of the tensor's own area.

    Returns:
        ``(bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple)``.
        For MemoryOnly and StartupInit passes the arrays are all zero.
    """
    if rewrite_list is None:
        # avoid a shared mutable default; the list is only iterated, never modified
        rewrite_list = ()
    if block_config is None:
        block_config = ps.block_config
    bws = make_bandwidth_array()
    macs = make_macs_array()
    cycles = make_cycles_array()
    blocks = 0
    ifm_read_multiple = 1
    weight_read_multiple = 0

    if ps.placement in set((PassPlacement.MemoryOnly, PassPlacement.StartupInit)):
        return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass

    min_block_size = arch.min_block_sizes[ps.npu_block_type]

    skirt = (0, 0, 0, 0)
    explicit_padding = (0, 0, 0, 0)
    primary_op = ps.primary_op
    # read bandwidth per tensor that replaces the default tens.bandwidth() further down
    replacement_read_bws = {}
    if primary_op:
        skirt = primary_op.attrs.get("skirt", skirt)
        explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
        assert primary_op.attrs["npu_block_type"] == ps.npu_block_type
        npu_block_type = primary_op.attrs["npu_block_type"]

        ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()

        if npu_block_type in set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling)
        ):
            # extend the ifm to full dimension
            ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1))
            ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
            ifm_tensor_bandwidth_shape = numeric_util.full_shape(4, ifm_tensor.bandwidth_shape, 1)

            batch_size = ifm_tensor.shape[0]
            ifm_depth = ifm_tensor_bandwidth_shape[3]

            # add in padding
            ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
            ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3]  # width += left and right

            strides = primary_op.attrs["strides"]
            if npu_block_type != NpuBlockType.Pooling:
                weight_tensor_shape = weight_tensor.shape
                weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
                weight_tensor_element_size = weight_tensor.element_size()
                weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
                nn_ops = (
                    int(ofm_tensor.shape[0])
                    * int(ofm_tensor.shape[1])
                    * int(ofm_tensor.shape[2])
                    * int(weight_tensor_shape[0])
                    * int(weight_tensor_shape[1])
                    * int(weight_tensor_shape[2])
                    * int(weight_tensor_shape[3])
                    / int(strides[1])
                    / int(strides[2])
                )
            else:
                # pooling has no weights: build a stand-in shape from the pooling window
                weight_tensor_shape = [
                    primary_op.attrs["ksize"][1],
                    primary_op.attrs["ksize"][2],
                    1,
                    ifm_tensor_shape[3],
                ]
                weight_tensor_bandwidth_shape = weight_tensor_shape
                weight_tensor_element_size = 0
                weight_tensor_bandwidth_compression_scale = 0.0
                nn_ops = 0  # pooling doesn't count as NN ops

            kernel_dims = weight_tensor_shape[:2]

            sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
            # count the sub kernels; the IFM block needs to be refetched for each of them
            n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
            n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
            n_sub_kernels = n_sub_kernels_y * n_sub_kernels_x

            clamped_skirt = list(skirt)
            clamped_skirt[2] = min(clamped_skirt[2], sub_kernel_limits[0] - 1 - clamped_skirt[0])
            clamped_skirt[3] = min(clamped_skirt[3], sub_kernel_limits[1] - 1 - clamped_skirt[1])
            n_blocks, area, block_setup = get_n_blocks_and_area(
                ifm_tensor_brick_size,
                ifm_tensor_shape[1:3],
                skirt,
                clamped_skirt,
                block_config,
                min_block_size,
                strides,
            )

            blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], block_config[3])

            n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], block_config[3])
            if npu_block_type == NpuBlockType.ConvolutionDepthWise or npu_block_type == NpuBlockType.Pooling:
                n_weight_stages = 1  # force to no reread

            ifm_tensor_bw = (
                n_sub_kernels
                * batch_size
                * area
                * ifm_depth
                * n_weight_stages
                * ifm_tensor.element_size()
                * ifm_tensor.bandwidth_compression_scale
            )
            replacement_read_bws[ifm_tensor] = ifm_tensor_bw
            ifm_read_multiple = n_weight_stages

            replacement_read_bws[weight_tensor] = (
                batch_size
                * shape_num_elements(weight_tensor_bandwidth_shape)
                * weight_tensor_element_size
                * weight_tensor_bandwidth_compression_scale
                * n_blocks
            )  # read once per block and batch
            weight_read_multiple = n_blocks

            n_kernel_xy = kernel_dims[0] * kernel_dims[1]
            n_input_channels_at_a_time = block_config[2]

            if npu_block_type == NpuBlockType.Pooling or weight_tensor.block_traversal in set(
                (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
            ):
                n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
                n_kernel_xy = max(
                    n_kernel_xy, 4
                )  # need at least 4, as this is the minimum duty cycle for secondary accumulator writes
                if weight_tensor is not None:
                    n_kernel_xy = numeric_util.round_up(n_kernel_xy, 4)  # weights need to be read in blocks of 4

            num_mac_ops = 0
            for n_blocks_for_size, block_size in block_setup:
                num_mac_ops += (
                    batch_size
                    * n_blocks_for_size
                    * block_size[0]
                    * block_size[1]
                    * numeric_util.round_up(weight_tensor_shape[2], n_input_channels_at_a_time)
                    * numeric_util.round_up(weight_tensor_shape[3], block_config[3])
                    * n_kernel_xy
                )

            if npu_block_type == NpuBlockType.Pooling:
                # TODO: improve pooling estimation
                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle / 2
            else:
                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
                macs[MacCount.NeuralNetworkMacs] += nn_ops
                macs[MacCount.HardwareMacs] += num_mac_ops

        elif npu_block_type == NpuBlockType.VectorProduct:
            nn_macs = (
                ifm_tensor.shape[0]
                * numeric_util.round_up(weight_tensor.shape[-2], block_config[2])
                * numeric_util.round_up(weight_tensor.shape[-1], block_config[3])
            )
            num_mac_ops = nn_macs

            cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
            macs[MacCount.NeuralNetworkMacs] += nn_macs
            macs[MacCount.HardwareMacs] += num_mac_ops

            blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], block_config[3])

            # scale the weight reads by the fraction of IFM positions that are non-zero in any batch
            non_zero_fraction = 1.0
            if ifm_tensor.values is not None:
                nz_vector = np.amax(ifm_tensor.values != 0, axis=0)  # max across batch axis
                non_zero_fraction = np.average(nz_vector)

            replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth()
            replacement_read_bws[weight_tensor] = weight_tensor.bandwidth() * non_zero_fraction
            ifm_read_multiple = 1
            weight_read_multiple = non_zero_fraction
    else:
        # NOTE(review): block nesting was reconstructed for this function; this branch is taken to
        # pair with ``if primary_op`` - confirm against the upstream file.
        if ps.placement == PassPlacement.Npu and len(ps.outputs):
            # Assume element-wise operation going through the element pipelines.
            # Work out how many elements we have and calculate performance.
            out = ps.outputs[0]
            elms = out.elements()

            cycles[PassCycles.ElementWise] = numeric_util.round_up_divide(elms, arch.num_elem_wise_units)

    if ps.placement == PassPlacement.Cpu:
        cycles[PassCycles.Cpu] = arch.cpu_cycle_estimate(ps.ops[0])

    # apply the desired rewrites
    for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list:
        if ps != ps_to_rewrite:
            continue
        if rewrite_op == SchedulerRewrite.Nop:
            pass  # these are fine, no bandwidth changes
        elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,):
            # move the tensor's read traffic to fast storage and stop counting it in its own area
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += replacement_read_bws[tens]
            replacement_read_bws[tens] = 0

    for tens in ps.outputs:
        if force_outputs_to_fast_storage:
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
        else:
            bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

    for tens in ps.intermediates:
        # intermediates are both written and read back within the pass
        bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

    for tens in ps.inputs:
        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

    cycles[PassCycles.SramAccess] = np.sum(bws[MemArea.Sram]) / arch.memory_bandwidths_per_cycle[MemArea.Sram]
    # the slowest of the entries before TotalPerPass (Dpu, ElementWise, Cpu, SramAccess) bounds the pass
    cycles[PassCycles.TotalPerPass] = np.max(cycles[: PassCycles.TotalPerPass])

    # quick build access counts for only current pass, even though these aren't the final numbers
    update_summary_cycles(arch, bws, macs, cycles)

    return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple
461
462
def update_summary_cycles(arch, bws, macs, cycles):
    """Fill in the DRAM/flash access cycles and the overall total, in place.

    Each access entry is the summed bandwidth of that memory area divided by
    the area's bandwidth per cycle. ``Total`` is the maximum of all entries
    that precede it. ``macs`` is accepted but not used. Returns ``cycles``.
    """
    access_entries = (
        (PassCycles.DramAccess, MemArea.Dram),
        (PassCycles.OnChipFlashAccess, MemArea.OnChipFlash),
        (PassCycles.OffChipFlashAccess, MemArea.OffChipFlash),
    )
    for cycle_entry, mem_area in access_entries:
        cycles[cycle_entry] = np.sum(bws[mem_area]) / arch.memory_bandwidths_per_cycle[mem_area]

    cycles[PassCycles.Total] = np.max(cycles[: PassCycles.Total])
    return cycles
474
475
def collate_stats_for_cascaded_pass(arch, bws, macs, cycles):
    """Return the given statistics unchanged as a ``(bws, macs, cycles)`` tuple; ``arch`` is unused."""
    collated = (bws, macs, cycles)
    return collated
478
479
def performance_for_cascaded_pass(arch, cps):
    """Estimate every pass of a cascaded pass and store the results on the objects.

    Each pass gets ``bandwidths``, ``macs``, ``cycles`` and ``n_blocks`` set;
    the cascaded pass gets the summed ``bandwidths``, ``macs`` and ``cycles``.
    Returns the summed ``(bws, macs, cycles)``.
    """
    sum_bws, sum_macs, sum_cycles = make_bandwidth_array(), make_macs_array(), make_cycles_array()

    for ps in cps.passes:
        ps_bws, ps_macs, ps_cycles, ps_blocks, _, _ = performance_metrics_for_pass(arch, ps)
        ps.bandwidths = ps_bws
        ps.macs = ps_macs
        ps.cycles = ps_cycles
        ps.n_blocks = ps_blocks
        sum_bws += ps_bws
        sum_macs += ps_macs
        sum_cycles += ps_cycles

    result = collate_stats_for_cascaded_pass(arch, sum_bws, sum_macs, sum_cycles)
    cps.bandwidths, cps.macs, cps.cycles = result
    return result
500
501
def calc_performance_for_network(nng, arch):
    """Sum the estimates of every cascaded pass in every subgraph and store them on ``nng``.

    Sets ``nng.bandwidths``, ``nng.macs`` and ``nng.cycles``; returns nothing.
    """
    net_bws = make_bandwidth_array()
    net_macs = np.zeros(MacCount.Size)
    net_cycles = np.zeros(PassCycles.Size)

    every_cascaded_pass = (cps for sg in nng.subgraphs for cps in sg.cascaded_passes)
    for cps in every_cascaded_pass:
        cps_bws, cps_macs, cps_cycles = performance_for_cascaded_pass(arch, cps)
        net_bws += cps_bws
        net_macs += cps_macs
        net_cycles += cps_cycles
        # the delay is added once per cascaded pass, to every entry of the cycles array
        net_cycles += arch.inter_pass_cycle_delay

    nng.bandwidths = net_bws
    nng.macs = net_macs
    nng.cycles = net_cycles
517 nng.cycles = total_cycles