# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for graph_optimiser
import numpy as np
import pytest

from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import calc_explicit_padding
from ethosu.vela.graph_optimiser import convert_batched_fc_shape
from ethosu.vela.graph_optimiser import optimise_graph_a
from ethosu.vela.graph_optimiser import replace_pad_by_hw_pad
from ethosu.vela.graph_optimiser import rewrite_fully_connected_input
from ethosu.vela.nn_graph import Graph
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
from ethosu.vela.rewrite_graph import verify_graph_health
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import Shape4D
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil


def test_convert_batched_fc():
    """Tests shape conversion of batched fully connected"""
    ifm_shape = [4, 8]
    ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
    w_shape = [8, 4]
    weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
    ofm = Tensor(ifm.shape, np.uint8, "test_out")
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)

    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)
    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
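    # The batched 2D shape [4, 8] is expected to be rewritten into an equivalent 4D shape
    # where the batch of 4 is split across height and width (2x2), keeping the depth of 8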
    assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert op.ofm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape

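    # Same test with batch size 1: no batch splitting is expected,
    # so the ifm/ofm shapes should be left unchanged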
    ifm.shape = [1, 8]
    weights.shape = [8, 1]
    ofm.shape = [1, 8]
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)

    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == prev_op.ifm_shapes[0]
    assert op.ofm_shapes[0] == prev_op.ofm_shapes[0]
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape


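# Each test case is [(input_size, stride, filter_size, pad_before, pad_after), expected (before, after)],
# matching the argument order of calc_explicit_padding (see test_calc_explicit_padding below)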
explicit_padding_test_data = [
    # Kernel size 2
    [(17, 1, 2, 1, 1), (1, 1)],
    [(18, 1, 2, 0, 1), (0, 1)],
    [(18, 1, 2, 1, 0), (1, 0)],
    # Kernel size 3
    [(18, 2, 3, 1, 1), (1, 0)],
    [(25, 2, 3, 1, 1), (1, 1)],
    # Kernel size 4
    [(18, 1, 4, 1, 2), (1, 2)],
    [(18, 1, 4, 2, 1), (2, 1)],
    [(19, 1, 4, 2, 2), (2, 2)],
    # Kernel size 5
    [(19, 1, 5, 1, 2), (1, 2)],
    [(19, 1, 5, 0, 2), (0, 2)],
    [(19, 1, 5, 1, 0), (1, 0)],
    # Kernel size 21
    [(41, 2, 21, 8, 10), (8, 10)],
    [(41, 3, 21, 10, 10), (10, 9)],
    [(42, 3, 21, 10, 10), (10, 8)],
    [(42, 3, 21, 9, 10), (9, 9)],
    [(41, 3, 21, 10, 6), (10, 6)],
]


@pytest.mark.parametrize("test_input, expected_result", explicit_padding_test_data)
def test_calc_explicit_padding(test_input, expected_result):
    input_size, stride, filter_size, explicit_pad_before, explicit_pad_after = test_input
    before, after = calc_explicit_padding(input_size, stride, filter_size, explicit_pad_before, explicit_pad_after)
    assert (before, after) == expected_result


def create_pad_and_conv2d(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
    pad_setting=Padding.VALID,
    kernel_size=3,
):
    """Creates a Pad operator followed by a Conv2D operator"""
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    op.run_on_npu = True
    conv_out_tens = Tensor(in_shape, in_dtype, "output")
    conv_out_tens.quantization = qp.clone()
    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
    weight_tens.values = np.zeros(weight_tens.shape)
    weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
    weight_tens.quantization = qp.clone()
    bias_tens = Tensor(out_shape, pad_dtype, "biases")
    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
    conv2d_op.add_input_tensor(out)
    conv2d_op.run_on_npu = True
    return op, conv2d_op


def test_pad_followed_by_conv_is_removed():
    """
    Tests that the PAD operator is bypassed when followed by a convolution operator,
    and that the padding of the convolution operation is correctly updated
    """
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 76, 75, 64], out_shape=[1, 76, 75, 64], padding=[[0, 0], [2, 1], [1, 1], [0, 0]], kernel_size=4
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

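    # The Pad op should now be fused into the convolution: the graph output is produced
    # directly by the conv, with the pad amounts carried as explicit (top, left, bottom, right)
    # hardware padding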
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.Conv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops


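# Each test case is (top, left, kernel_size, expect_pad_removed); the generated Pad/Conv2D
# pair uses stride 2 (see create_pad_and_conv2d above)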
leading_pad_test_data = [
    (2, 2, 11, True),
    (1, 2, 11, False),
    (2, 1, 11, False),
    (5, 2, 11, True),
]


@pytest.mark.parametrize("top, left, kernel_size, expect_pad_removed", leading_pad_test_data)
def test_leading_pad_size(top, left, kernel_size, expect_pad_removed):
    # Tests PAD operator with big kernel size; top and left pad must be multiple of stride
    out_shape = [1, 11 + left, 11 + top, 1]
    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()
    replace_pad_by_hw_pad(conv2d_op, nng, arch)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert "explicit_padding" in op.attrs
        assert op.ifm.shape == op.ofm.shape
        assert pad_op not in op.ifm.ops
    else:
        assert pad_op in op.ifm.ops
        assert op.attrs["padding"] == Padding.VALID
        assert "explicit_padding" not in op.attrs


def test_optimise_pad_followed_by_avg_pool():
    """
    Tests that the PAD operator is bypassed when followed by an average pool operator,
    and that the average pool is converted to a depthwise convolution
    """
    # Create Pad operation followed by AvgPool
    quant = testutil.default_quant_params()
    in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
    in_tens.quantization = quant
    # Test with a 3x2 pad input tensor
    pad_input = create_const_tensor("pad_input", [3, 2], DataType.int32, [[2, 2], [1, 1], [0, 0]])
    temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
    temp_tens.quantization = quant.clone()
    out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
    out_tens.quantization = quant.clone()

    pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, 5, 3, 1],
        "stride_w": 2,
        "stride_h": 2,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    pad_op.run_on_npu = True
    conv2d_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
    conv2d_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.DepthwiseConv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 2, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops
    # Check that bias and weight tensors have been added
    assert op.bias.shape == [64]
    assert op.weights.shape == [5, 3, 1, 64]


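# Each test case is ((kernel_w, kernel_h), (top, left, bottom, right), expect_pad_removed)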
pad_avg_pool_test_data = [
    ((3, 3), (1, 1, 1, 1), True),
    ((3, 3), (2, 1, 1, 1), False),
    ((3, 3), (1, 2, 1, 1), False),
    ((3, 3), (1, 1, 2, 1), False),
    ((3, 3), (1, 1, 1, 2), False),
    ((2, 4), (1, 2, 1, 2), True),
    ((5, 3), (2, 1, 2, 1), True),
    ((5, 3), (0, 1, 2, 1), True),
    ((5, 3), (2, 0, 2, 1), True),
    ((5, 3), (2, 1, 0, 1), True),
    ((5, 3), (2, 1, 0, 1), True),
    ((4, 4), (2, 2, 2, 2), True),
    ((4, 4), (1, 2, 2, 2), False),
    ((4, 4), (2, 1, 2, 2), False),
    ((4, 4), (2, 2, 1, 2), False),
    ((4, 4), (2, 2, 2, 1), False),
]


@pytest.mark.parametrize("k_size, padding, expect_pad_removed", pad_avg_pool_test_data)
def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
    # Tests PAD followed by AvgPool
    k_w, k_h = k_size
    top, left, bottom, right = padding
    pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
    dtype = DataType.int8
    qp = testutil.default_quant_params()
    in_shape = [1, 15, 17, 8]
    out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
    in0 = Tensor(in_shape, dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(
        name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
    )
    out = Tensor(out_shape, dtype, "out")
    out.quantization = qp.clone()
    pad_op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    pool_out_tens = Tensor(in_shape, dtype, "output")
    pool_out_tens.quantization = qp.clone()
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, k_w, k_h, 1],
        "stride_w": 1,
        "stride_h": 1,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
    pad_op.run_on_npu = True
    pool_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, pool_op])
    arch = testutil.create_arch()
    nng = optimise_graph_a(nng, arch)
    sg = nng.subgraphs[0]
    all_ops = sg.get_all_ops()
    print("all_ops: ", all_ops)
    # Pad should not be in the graph anymore, it should either have been removed or rewritten
    assert not any(op.type == Op.Pad for op in all_ops)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        # Expect rewrite to depthwise, PAD is removed
        assert op.type == Op.DepthwiseConv2DBias
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert any(pad > 0 for pad in op.attrs["explicit_padding"])
        assert op.ifm.shape == op.ofm.shape
        # Check that bias and weight tensors have been added
        assert len(op.bias.shape) > 0
        assert op.weights.shape is not None
    else:
        # Pad should have been rewritten to a number of average pool operations
        assert all(op.type in (Op.AvgPool, Op.Const) for op in all_ops)
        assert pool_op.type == Op.AvgPool
        assert pool_op.attrs["padding"] == Padding.VALID


def test_remove_reshape():
    """
    Tests that the expected Reshape ops are removed during graph optimisation
    """

    def setup_network():
        quant = testutil.default_quant_params()
        # create reshape1 op
        ifm_shape = [64, 16]
        reshape1_ofm_shape = [1, 4, 16, 16]
        reshape1_ifm = create_const_tensor("reshape1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
        reshape1_ifm.quantization = quant
        reshape1_ofm = create_const_tensor(
            "reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape)
        )
        reshape1_ofm.quantization = quant
        shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
        reshape1_op = testutil.create_op(Op.Reshape, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
        reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
        reshape1_op.run_on_npu = True

        # create conv op
        conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
        conv_ofm.quantization = quant.clone()
        weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
        weight_tens.values = np.zeros(weight_tens.shape)
        weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
        weight_tens.quantization = quant.clone()
        bias_tens = Tensor([16], DataType.int32, "biases")

        attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
        attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)

        conv2d_op = testutil.create_op(
            Op.Conv2D, [reshape1_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
        )
        conv2d_op.run_on_npu = True

        # create reshape2 op
        ofm_shape = [8, 8, 16]
        reshape2_ofm = create_const_tensor("reshape2_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
        reshape2_ofm.quantization = quant
        shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, ofm_shape)
        reshape2_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
        reshape2_op.attrs["new_shape"] = ofm_shape
        reshape2_op.run_on_npu = True
        nng = Graph()
        sg = testutil.create_subgraph([reshape1_op, conv2d_op, reshape2_op])
        nng.subgraphs.append(sg)

        return nng, reshape1_op, conv2d_op, reshape2_op

    # Test 1: no Reshape op is expected to remain in the NPU subgraph,
    # but the first one will be put on the CPU
    # Network is Reshape-Conv-Reshape
    # Result is Conv
    nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
    arch = testutil.create_arch()
    assert verify_graph_health(nng)
    nng = optimise_graph_a(nng, arch)
    assert verify_graph_health(nng)

    # Test 2: reshape1 has different quantisation, so this Reshape op is expected to remain
    # Network is Reshape-Conv-Reshape
    # Expected result is Reshape-Conv
    nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
    quant_zp32 = testutil.default_quant_params()
    quant_zp32.zero_point = 32
    reshape1_op.ofm.quantization = quant_zp32
    assert verify_graph_health(nng)
    nng = optimise_graph_a(nng, arch)
    assert verify_graph_health(nng)