# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# NPU performance estimation functions to estimate performance of a Pass and CascadedPass. Uses a model that takes the
# maximum of the 'cycles required for bandwidth' and 'cycles required for computing'.
#
# Called during scheduling to evaluate different proposals, as well as post-scheduling to provide a final performance
# estimate.
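#
# For example (illustrative numbers only, not from real hardware): a pass that
# needs 10000 cycles of DPU compute but whose SRAM traffic would take 15000
# cycles to transfer is estimated at max(10000, 15000) = 15000 cycles, i.e.
# the pass is considered bandwidth bound.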
import enum

import numpy as np

from . import numeric_util
from .architecture_features import Block
from .architecture_features import SHRAMElements
from .data_type import DataType
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .operation import NpuBlockType
from .operation import Op
from .tensor import MemArea
from .tensor import shape_num_elements
from .tensor import TensorBlockTraversal
from .tensor import TensorPurpose


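# Computes the [height, width] a rolling buffer needs so that ps1 can produce
# into it while ps2 consumes from it, given both passes' block configurations.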
def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_config_ps2):
    ofm_block = Block(block_config_ps2[-3], block_config_ps2[-4], block_config_ps2[-1])
    kernel = ps2.primary_op.kernel

    if ps2.npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
        op = ps2.primary_op
        ifm_block_depth = arch.calc_ifm_block_depth(op.ifm.shape[-1], op.ifm.dtype.size_in_bits())
    else:
        ifm_block_depth = block_config_ps2[-1]

    ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, arch.ofm_block_max)

    # Worst-case height: one IFM block plus one ps1 block, rounded up to a whole number of ps1 blocks
    height = numeric_util.round_up(ifm_block.height + block_config_ps1[0], block_config_ps1[0])
    width = ifm_block.width
    return [height, width]


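# Cycle counters tracked per pass. The enum order is significant: TotalPerPass
# is later filled with the max over the entries preceding it (Dpu, ElementWise,
# Cpu, SramAccess), and Total with the max over everything preceding it; this
# is how the max(bandwidth, compute) model described above is realised.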
class PassCycles(enum.IntEnum):
    Dpu = 0
    ElementWise = 1
    Cpu = 2
    SramAccess = 3
    TotalPerPass = 4
    DramAccess = 5
    OnChipFlashAccess = 6
    OffChipFlashAccess = 7
    Total = 8
    Size = 9

    def display_name(self):
        return (
            "DPU",
            "Element wise",
            "CPU",
            "SRAM Access",
            "Total per Pass",
            "DRAM Access",
            "On-chip Flash Access",
            "Off-chip Flash Access",
            "Total",
            "Size",
        )[self.value]

    def identifier_name(self):
        return (
            "dpu",
            "element_wise",
            "cpu",
            "sram_access",
            "total_per_pass",
            "dram_access",
            "on_chip_flash_access",
            "off_chip_flash_access",
            "total",
            "size",
        )[self.value]

    @staticmethod
    def all():
        return (
            PassCycles.Dpu,
            PassCycles.ElementWise,
            PassCycles.Cpu,
            PassCycles.SramAccess,
            PassCycles.DramAccess,
            PassCycles.OnChipFlashAccess,
            PassCycles.OffChipFlashAccess,
            PassCycles.Total,
        )


class MacCount(enum.IntEnum):
    NeuralNetworkMacs = 0
    HardwareMacs = 1
    Size = 2

    def display_name(self):
        return ("Neural Network Macs", "Hardware Macs", "Size")[self.value]

    def identifier_name(self):
        return ("nn_macs", "hardware_macs", "size")[self.value]

    @staticmethod
    def all():
        return (MacCount.NeuralNetworkMacs, MacCount.HardwareMacs)


class BandwidthDirection(enum.IntEnum):
    Read = 0
    Write = 1
    Size = 2

    def display_name(self):
        return self.name

    def identifier_name(self):
        return self.name.lower()

    @staticmethod
    def all():
        return (BandwidthDirection.Read, BandwidthDirection.Write)


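# Metrics are accumulated in dense numpy arrays indexed by the enums above;
# each enum's Size member gives the extent of the corresponding axis.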
def make_bandwidth_array():
    return np.zeros((MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size))


def make_macs_array():
    return np.zeros(MacCount.Size, int)  # the deprecated np.int alias was simply the builtin int


def make_cycles_array():
    return np.zeros(PassCycles.Size)


def make_metrics_arrays():
    return (make_bandwidth_array(), make_macs_array(), make_cycles_array())


def get_n_blocks_and_area(
    ifm_brick_size, ifm_height_width, orig_skirt, clamped_skirt, block_config, min_block_size, strides
):
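    # Tiles the (skirt-adjusted) IFM into the blocks read for one pass and
    # returns (total number of blocks, total IFM area read including overlap
    # and edge padding, the block shapes used for the MAC count later on).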

    ifm_block_config = (block_config[0] * strides[1], block_config[1] * strides[2])

    n_normal_blocks = []
    remainder_size = []
    for i in range(2):
        non_skirt_dim = ifm_height_width[i] - orig_skirt[i] - orig_skirt[2 + i]
        n_blocks = non_skirt_dim // ifm_block_config[i]
        n_normal_blocks.append(n_blocks)
        remainder_dim = numeric_util.round_up(
            ((non_skirt_dim - n_blocks * ifm_block_config[i] - 1) // strides[i + 1]) + 1, min_block_size[i]
        )
        remainder_size.append(remainder_dim)

    # This deliberately counts reads into the edge padding as well.
    #
    # There are four cases in total; the edges that do not fill a complete
    # block are handled separately. In the sketch below, 0 marks the full
    # blocks, 1 the right-edge remainder, 2 the bottom-edge remainder and
    # 3 the corner remainder:
    #
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 0000000001
    # 2222222223
    total_blocks = 0
    total_area = 0

    block_setup = (
        (n_normal_blocks[0] * n_normal_blocks[1], block_config),
        (1 * n_normal_blocks[1], (remainder_size[0], block_config[1])),
        (n_normal_blocks[0] * 1, (block_config[0], remainder_size[1])),
        (1 * 1, remainder_size),
    )

    for n_blocks, block_size in block_setup:
        if block_size[0] == 0 or block_size[1] == 0:
            continue
        read_dims = [0, 0]
        for i in range(2):
            read_dims[i] = (
                numeric_util.round_up(clamped_skirt[i], ifm_brick_size[i + 1])
                + block_size[i] * strides[i + 1]
                + numeric_util.round_up(clamped_skirt[2 + i], ifm_brick_size[i + 1])
            )
        assert n_blocks >= 0
        total_blocks += n_blocks
        total_area += n_blocks * read_dims[0] * read_dims[1]
    assert total_blocks >= 1
    return total_blocks, total_area, block_setup


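# Estimates the cycles needed to produce the output elements of a pass: a
# per-element cost is looked up in the architecture's tables, indexed by the
# output case and by the fused activation function, and the slower of the two
# rates dominates.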
def get_output_cycle_estimate(arch, ps):
    primary_op = ps.primary_op
    assert primary_op
    npu_block_type = primary_op.type.npu_block_type
    faf = primary_op.activation

    if npu_block_type == NpuBlockType.ElementWise and ps.ifm_tensor.dtype == DataType.int32:
        if ps.ifm2_tensor is None:
            # Unary op
            output_perf_index = 0
        else:
            # Binary op
            output_perf_index = 1
    elif ps.primary_op.type == Op.Mul and ps.ofm_tensor.dtype == DataType.int32:
        output_perf_index = 2
    elif ps.primary_op.type == Op.Mul or (
        npu_block_type
        in (
            NpuBlockType.ConvolutionMxN,
            NpuBlockType.ConvolutionDepthWise,
            NpuBlockType.Pooling,
            NpuBlockType.ReduceSum,
            NpuBlockType.VectorProduct,
        )
        and ps.shared_buffer.use_accumulator_element == SHRAMElements.Acc40
    ):
        output_perf_index = 3
    elif ps.primary_op.type in (Op.Add, Op.Sub):
        input_scale = ps.ifm_tensor.quantization.scale_f32
        input2_scale = ps.ifm2_tensor.quantization.scale_f32
        output_scale = ps.ofm_tensor.quantization.scale_f32

        if "resizebilinear" in primary_op.attrs:
            output_scale = input2_scale

        if None in (input_scale, input2_scale, output_scale) or input_scale == input2_scale:
            # Simple Add/Sub
            output_perf_index = 4
        else:
            # Advanced Add/Sub
            output_perf_index = 5
    elif ps.primary_op.type.is_maxpool_op():
        output_perf_index = 6
    else:
        output_perf_index = 7

    if faf in (Op.Sigmoid, Op.Tanh, Op.LUT):
        activation_perf_index = 0
    elif faf in (Op.Relu, Op.Relu6, Op.ReluN1To1):
        activation_perf_index = 1
    else:
        activation_perf_index = 2

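    # Illustrative example (assumed table values, not real hardware figures):
    # with output_cycles_per_elem[output_perf_index] = 0.25 and
    # activation_cycles_per_elem[activation_perf_index] = 1.0, a 1x56x56x32
    # output (100352 elements) is estimated at 100352 * max(0.25, 1.0)
    # = 100352 cycles.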
    num_elems = ps.outputs[0].elements()
    cycle_per_elem = max(
        arch.output_cycles_per_elem[output_perf_index], arch.activation_cycles_per_elem[activation_perf_index]
    )
    return num_elems * cycle_per_elem


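# Estimates all performance metrics for a single pass; returns the tuple
# (bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple).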
def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, force_outputs_to_fast_storage=False):
    if block_config is None:
        block_config = ps.block_config
    if rewrite_list is None:
        rewrite_list = []  # avoid a mutable default argument
    bws = make_bandwidth_array()
    macs = make_macs_array()
    cycles = make_cycles_array()
    blocks = 0
    ifm_read_multiple = 1
    weight_read_multiple = 0

    if ps.placement in set((PassPlacement.MemoryOnly, PassPlacement.StartupInit)):
        return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple  # nothing real happening in this pass

    min_block_size = arch.min_block_sizes[ps.npu_block_type]

    skirt = (0, 0, 0, 0)
    explicit_padding = (0, 0, 0, 0)
    primary_op = ps.primary_op
    replacement_read_bws = {}
    if ps.placement == PassPlacement.Cpu:
        cycles[PassCycles.Cpu] = arch.cpu_cycle_estimate(ps.ops[0])
    elif primary_op:
        skirt = primary_op.attrs.get("skirt", skirt)
        explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
        assert primary_op.type.npu_block_type == ps.npu_block_type
        npu_block_type = primary_op.type.npu_block_type

        ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()

        if npu_block_type in set(
            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling)
        ):
            # extend the IFM shape to the full 4 dimensions
            ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1))
            ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
            ifm_tensor_bandwidth_shape = numeric_util.full_shape(4, ifm_tensor.bandwidth_shape, 1)

            batch_size = ifm_tensor.shape[0]
            ifm_depth = ifm_tensor_bandwidth_shape[3]

            # add in padding
            ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
            ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3]  # width += left and right

            strides = primary_op.attrs["strides"]
            if npu_block_type != NpuBlockType.Pooling:
                weight_tensor_shape = weight_tensor.shape
                weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
                weight_tensor_element_size = weight_tensor.element_size()
                weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
                nn_ops = (
                    int(ofm_tensor.shape[0])
                    * int(ofm_tensor.shape[1])
                    * int(ofm_tensor.shape[2])
                    * int(weight_tensor_shape[0])
                    * int(weight_tensor_shape[1])
                    * int(weight_tensor_shape[2])
                    * int(weight_tensor_shape[3])
                )
            else:
                weight_tensor_shape = [
                    primary_op.attrs["ksize"][1],
                    primary_op.attrs["ksize"][2],
                    1,
                    ifm_tensor_shape[3],
                ]
                weight_tensor_bandwidth_shape = weight_tensor_shape
                weight_tensor_element_size = 0
                weight_tensor_bandwidth_compression_scale = 0.0
                nn_ops = 0  # pooling doesn't count as NN ops

            kernel_dims = weight_tensor_shape[:2]

            sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
            # count the sub kernels; the IFM block needs to be refetched for each of them
            n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
            n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
            n_sub_kernels = n_sub_kernels_y * n_sub_kernels_x

            clamped_skirt = list(skirt)
            clamped_skirt[2] = min(clamped_skirt[2], sub_kernel_limits[0] - 1 - clamped_skirt[0])
            clamped_skirt[3] = min(clamped_skirt[3], sub_kernel_limits[1] - 1 - clamped_skirt[1])
            n_blocks, area, block_setup = get_n_blocks_and_area(
                ifm_tensor_brick_size,
                ifm_tensor_shape[1:3],
                skirt,
                clamped_skirt,
                block_config,
                min_block_size,
                strides,
            )

            blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], block_config[3])

            n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], block_config[3])
            if npu_block_type == NpuBlockType.ConvolutionDepthWise or npu_block_type == NpuBlockType.Pooling:
                n_weight_stages = 1  # force no re-read

            ifm_tensor_bw = (
                n_sub_kernels
                * batch_size
                * area
                * ifm_depth
                * n_weight_stages
                * ifm_tensor.element_size()
                * ifm_tensor.bandwidth_compression_scale
            )
            replacement_read_bws[ifm_tensor] = ifm_tensor_bw
            ifm_read_multiple = n_weight_stages

            replacement_read_bws[weight_tensor] = (
                batch_size
                * shape_num_elements(weight_tensor_bandwidth_shape)
                * weight_tensor_element_size
                * weight_tensor_bandwidth_compression_scale
                * n_blocks
            )  # read once per block and batch
            weight_read_multiple = n_blocks

            n_kernel_xy = kernel_dims[0] * kernel_dims[1]
            n_input_channels_at_a_time = block_config[2]

            if npu_block_type == NpuBlockType.Pooling or weight_tensor.block_traversal in set(
                (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
            ):
                n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
                n_kernel_xy = max(
                    n_kernel_xy, 4
                )  # need at least 4, as this is the minimum duty cycle for secondary accumulator writes
                if weight_tensor is not None:
                    n_kernel_xy = numeric_util.round_up(n_kernel_xy, 4)  # weights need to be read in blocks of 4

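            # Hardware MAC count: every block pays for whole hardware steps,
            # so IFM depth is rounded up to the channels handled per step and
            # OFM depth to the block depth, for every kernel position.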
            num_mac_ops = 0
            for n_blocks_for_size, block_size in block_setup:
                num_mac_ops += (
                    batch_size
                    * n_blocks_for_size
                    * block_size[0]
                    * block_size[1]
                    * numeric_util.round_up(weight_tensor_shape[2], n_input_channels_at_a_time)
                    * numeric_util.round_up(weight_tensor_shape[3], block_config[3])
                    * n_kernel_xy
                )

            if npu_block_type == NpuBlockType.Pooling:
                # TODO: improve pooling estimation
                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle / 2
            else:
                cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
            macs[MacCount.NeuralNetworkMacs] += nn_ops
            macs[MacCount.HardwareMacs] += num_mac_ops

        elif npu_block_type == NpuBlockType.VectorProduct:
            nn_macs = (
                ifm_tensor.shape[0]
                * numeric_util.round_up(weight_tensor.shape[-2], block_config[2])
                * numeric_util.round_up(weight_tensor.shape[-1], block_config[3])
            )
            num_mac_ops = nn_macs

            cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
            macs[MacCount.NeuralNetworkMacs] += nn_macs
            macs[MacCount.HardwareMacs] += num_mac_ops

            blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], block_config[3])

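            # When the IFM values are compile-time constants, scale the weight
            # read bandwidth by the fraction of IFM positions that are non-zero
            # in at least one batch; weights that only ever multiply zeros are
            # assumed not to need fetching.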
            non_zero_fraction = 1.0
            if ifm_tensor.values is not None:
                nz_vector = np.amax(ifm_tensor.values != 0, axis=0)  # max across batch axis
                non_zero_fraction = np.average(nz_vector)

            replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth()
            replacement_read_bws[weight_tensor] = weight_tensor.bandwidth() * non_zero_fraction
            ifm_read_multiple = 1
            weight_read_multiple = non_zero_fraction
        elif npu_block_type == NpuBlockType.ElementWise:
            # Work out how many elements we have and calculate performance.
            cycles[PassCycles.ElementWise] = get_output_cycle_estimate(arch, ps)

    # apply the desired rewrites
    for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list:
        if ps != ps_to_rewrite:
            continue
        if rewrite_op == SchedulerRewrite.Nop:
            pass  # these are fine, no bandwidth changes
        elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,):
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += replacement_read_bws[tens]
            replacement_read_bws[tens] = 0

    for tens in ps.outputs:
        if force_outputs_to_fast_storage:
            bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
        else:
            bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

    for tens in ps.intermediates:
        bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()

        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

    for tens in ps.inputs:
        if tens in replacement_read_bws:
            bw = replacement_read_bws[tens]
        else:
            bw = tens.bandwidth()

        bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw

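    # Apply the max(bandwidth, compute) model: convert SRAM traffic into
    # cycles, then take the slowest of the compute and access counters as the
    # per-pass estimate.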
    cycles[PassCycles.SramAccess] = np.sum(bws[MemArea.Sram]) / arch.memory_bandwidths_per_cycle[MemArea.Sram]
    cycles[PassCycles.TotalPerPass] = np.max(cycles[: PassCycles.TotalPerPass])

    # quickly build the summary access counts for this pass alone, even though these aren't the final numbers
    update_summary_cycles(arch, bws, macs, cycles)

    return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple


def update_summary_cycles(arch, bws, macs, cycles):
    cycles[PassCycles.DramAccess] = np.sum(bws[MemArea.Dram]) / arch.memory_bandwidths_per_cycle[MemArea.Dram]
    cycles[PassCycles.OnChipFlashAccess] = (
        np.sum(bws[MemArea.OnChipFlash]) / arch.memory_bandwidths_per_cycle[MemArea.OnChipFlash]
    )
    cycles[PassCycles.OffChipFlashAccess] = (
        np.sum(bws[MemArea.OffChipFlash]) / arch.memory_bandwidths_per_cycle[MemArea.OffChipFlash]
    )

    cycles[PassCycles.Total] = np.max(cycles[: PassCycles.Total])
    return cycles


def collate_stats_for_cascaded_pass(arch, bws, macs, cycles):
    return bws, macs, cycles


def performance_for_cascaded_pass(arch, cps):
    total_bws = make_bandwidth_array()
    total_macs = make_macs_array()
    total_cycles = make_cycles_array()

    for ps in cps.passes:
        bws, macs, cycles, blocks, _, _ = performance_metrics_for_pass(arch, ps)
        ps.bandwidths = bws
        ps.macs = macs
        ps.cycles = cycles
        ps.n_blocks = blocks
        total_bws += bws
        total_macs += macs
        total_cycles += cycles

    bws, macs, cycles = collate_stats_for_cascaded_pass(arch, total_bws, total_macs, total_cycles)
    cps.bandwidths = bws
    cps.macs = macs
    cps.cycles = cycles
    return bws, macs, cycles


def calc_performance_for_network(nng, arch):
    total_bws = make_bandwidth_array()
    total_macs = np.zeros(MacCount.Size)
    total_cycles = np.zeros(PassCycles.Size)

    for sg in nng.subgraphs:
        for cps in sg.cascaded_passes:
            bws, macs, cycles = performance_for_cascaded_pass(arch, cps)
            total_bws += bws
            total_macs += macs
            total_cycles += cycles

    nng.bandwidths = total_bws
    nng.macs = total_macs
    nng.cycles = total_cycles
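

# A minimal usage sketch (hypothetical driver code, not part of this module):
# after scheduling, estimate the whole network and read back the totals.
#
#     calc_performance_for_network(nng, arch)
#     total_cycles = nng.cycles[PassCycles.Total]
#     sram_reads = np.sum(nng.bandwidths[MemArea.Sram, :, BandwidthDirection.Read])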