# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for tflite_graph_optimiser
import numpy as np
import pytest

from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import optimise_graph
from ethosu.vela.nn_graph import Graph
from ethosu.vela.nn_graph import NetworkType
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
from ethosu.vela.rewrite_graph import verify_graph_health
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import Shape4D
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
from ethosu.vela.tflite_graph_optimiser import calc_explicit_padding
from ethosu.vela.tflite_graph_optimiser import convert_batched_fc_shape
from ethosu.vela.tflite_graph_optimiser import replace_pad_by_hw_pad
from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input


def test_convert_batched_fc():
    """Tests shape conversion of batched fully connected"""
    ifm_shape = [4, 8]
    ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
    w_shape = [8, 4]
    weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
    ofm = Tensor(ifm.shape, np.uint8, "test_out")
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)

    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)
    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert op.ofm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape

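    # Second case: batch 1. No 4D rewrite is needed, so the ifm/ofm shapes are
    # expected to be unchanged by the rewrite passes.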
    ifm.shape = [1, 8]
    weights.shape = [8, 1]
    ofm.shape = [1, 8]
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)

    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == prev_op.ifm_shapes[0]
    assert op.ofm_shapes[0] == prev_op.ofm_shapes[0]
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape


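# Each entry is ((input_size, stride, filter_size, pad_before, pad_after), (expected_before, expected_after)).
# calc_explicit_padding clips the requested padding to what the sliding window actually
# consumes. For example, with input_size=18, stride=2, filter_size=3, a requested (1, 1)
# yields (1, 0): the last window position ends on the final input element, so the
# after-pad goes unused, whereas input_size=25 keeps the full (1, 1).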
explicit_padding_test_data = [
    # Kernel size 2
    [(17, 1, 2, 1, 1), (1, 1)],
    [(18, 1, 2, 0, 1), (0, 1)],
    [(18, 1, 2, 1, 0), (1, 0)],
    # Kernel size 3
    [(18, 2, 3, 1, 1), (1, 0)],
    [(25, 2, 3, 1, 1), (1, 1)],
    # Kernel size 4
    [(18, 1, 4, 1, 2), (1, 2)],
    [(18, 1, 4, 2, 1), (2, 1)],
    [(19, 1, 4, 2, 2), (2, 2)],
    # Kernel size 5
    [(19, 1, 5, 1, 2), (1, 2)],
    [(19, 1, 5, 0, 2), (0, 2)],
    [(19, 1, 5, 1, 0), (1, 0)],
    # Kernel size 21
    [(41, 2, 21, 8, 10), (8, 10)],
    [(41, 3, 21, 10, 10), (10, 9)],
    [(42, 3, 21, 10, 10), (10, 8)],
    [(42, 3, 21, 9, 10), (9, 9)],
    [(41, 3, 21, 10, 6), (10, 6)],
]


@pytest.mark.parametrize("test_input, expected_result", explicit_padding_test_data)
def test_calc_explicit_padding(test_input, expected_result):
    input_size, stride, filter_size, explicit_pad_before, explicit_pad_after = test_input
    before, after = calc_explicit_padding(input_size, stride, filter_size, explicit_pad_before, explicit_pad_after)
    assert (before, after) == expected_result


def create_pad_and_conv2d(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
    pad_setting=Padding.VALID,
    kernel_size=3,
):
    """Creates Pad operator followed by a conv2d operator"""
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    op.run_on_npu = True
    conv_out_tens = Tensor(in_shape, in_dtype, "output")
    conv_out_tens.quantization = qp.clone()
    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
    weight_tens.values = np.zeros(weight_tens.shape)
    weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
    weight_tens.quantization = qp.clone()
    bias_tens = Tensor(out_shape, pad_dtype, "biases")
    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
    conv2d_op.add_input_tensor(out)
    conv2d_op.run_on_npu = True
    return op, conv2d_op


def test_pad_followed_by_conv_is_removed():
    """
    Tests that the PAD operator is bypassed when followed by a convolution operator,
    and that the padding of the convolution operation is correctly updated
    """
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 76, 75, 64], out_shape=[1, 76, 75, 64], padding=[[0, 0], [2, 1], [1, 1], [0, 0]], kernel_size=4
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.Conv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
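    # explicit_padding is ordered (top, left, bottom, right): the [2, 1] height pads
    # and [1, 1] width pads of the PAD op above map to (2, 1, 1, 1)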
    assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops


leading_pad_test_data = [
    (2, 2, 11, True),
    (1, 2, 11, False),
    (2, 1, 11, False),
    (5, 2, 11, True),
]


@pytest.mark.parametrize("top, left, kernel_size, expect_pad_removed", leading_pad_test_data)
def test_leading_pad_size(top, left, kernel_size, expect_pad_removed):
    # Tests PAD operator with big kernel size; top and left pad must be multiple of stride
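    # Bottom/right pads are 0, so the padded output grows only by top (height) and left (width)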
    out_shape = [1, 11 + top, 11 + left, 1]
    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()
    replace_pad_by_hw_pad(conv2d_op, nng, arch)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert "explicit_padding" in op.attrs
        assert op.ifm.shape == op.ofm.shape
        assert pad_op not in op.ifm.ops
    else:
        assert pad_op in op.ifm.ops
        assert op.attrs["padding"] == Padding.VALID
        assert "explicit_padding" not in op.attrs


def test_optimise_pad_followed_by_avg_pool():
207 """
208 Tests that the PAD operator is bypassed when followed by a average pool operator,
209 and that the average pool is converted to a depthwise
210 """
    # Create Pad operation followed by AvgPool
    quant = testutil.default_quant_params()
    in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
    in_tens.quantization = quant
    # Test with 3x2 input tensor
    pad_input = create_const_tensor("pad_input", [3, 2], DataType.int32, [[2, 2], [1, 1], [0, 0]])
    temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
    temp_tens.quantization = quant.clone()
    out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
    out_tens.quantization = quant.clone()

    pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, 5, 3, 1],
        "stride_w": 2,
        "stride_h": 2,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    pad_op.run_on_npu = True
    conv2d_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
    conv2d_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.DepthwiseConv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 2, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops
    # Check that bias and weight tensors have been added
    assert op.bias.shape == [64]
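    # The 5x3 ksize of the original average pool is carried over into the depthwise weights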
    assert op.weights.shape == [5, 3, 1, 64]


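# Each entry: ((kernel_w, kernel_h), (top, left, bottom, right), expect_pad_removed).
# Judging by the cases below, the pad is only folded into hardware padding when every
# side's pad is 0 or exactly kernel // 2 for its dimension; otherwise the PAD op is
# rewritten as separate average pool operations (checked in the else-branch of the test).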
pad_avg_pool_test_data = [
    ((3, 3), (1, 1, 1, 1), True),
    ((3, 3), (2, 1, 1, 1), False),
    ((3, 3), (1, 2, 1, 1), False),
    ((3, 3), (1, 1, 2, 1), False),
    ((3, 3), (1, 1, 1, 2), False),
    ((2, 4), (1, 2, 1, 2), True),
    ((5, 3), (2, 1, 2, 1), True),
    ((5, 3), (0, 1, 2, 1), True),
    ((5, 3), (2, 0, 2, 1), True),
    ((5, 3), (2, 1, 0, 1), True),
    ((5, 3), (2, 1, 0, 1), True),
    ((4, 4), (2, 2, 2, 2), True),
    ((4, 4), (1, 2, 2, 2), False),
    ((4, 4), (2, 1, 2, 2), False),
    ((4, 4), (2, 2, 1, 2), False),
    ((4, 4), (2, 2, 2, 1), False),
]


@pytest.mark.parametrize("k_size, padding, expect_pad_removed", pad_avg_pool_test_data)
def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
    # Tests PAD followed by AvgPool
    k_w, k_h = k_size
    top, left, bottom, right = padding
    pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
    dtype = DataType.int8
    qp = testutil.default_quant_params()
    in_shape = [1, 15, 17, 8]
    out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
    in0 = Tensor(in_shape, dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(
        name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
    )
    out = Tensor(out_shape, dtype, "out")
    out.quantization = qp.clone()
    pad_op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    pool_out_tens = Tensor(in_shape, dtype, "output")
    pool_out_tens.quantization = qp.clone()
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, k_w, k_h, 1],
        "stride_w": 1,
        "stride_h": 1,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
    pad_op.run_on_npu = True
    pool_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, pool_op])
    arch = testutil.create_arch()
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    sg = nng.subgraphs[0]
    all_ops = sg.get_all_ops()
    print("all_ops: ", all_ops)
    # Pad should not be in the graph anymore; it should either have been removed or rewritten
    assert not any(op.type == Op.Pad for op in all_ops)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        # Expect rewrite to depthwise, PAD is removed
        assert op.type == Op.DepthwiseConv2DBias
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert any(pad > 0 for pad in op.attrs["explicit_padding"])
        assert op.ifm.shape == op.ofm.shape
        # Check that bias and weight tensors have been added
        assert len(op.bias.shape) > 0
        assert op.weights.shape is not None
    else:
        # Pad should have been rewritten to a number of average pool operations
        assert all(op.type in (Op.AvgPool, Op.Const) for op in all_ops)
        assert pool_op.type == Op.AvgPool
        assert pool_op.attrs["padding"] == Padding.VALID


def test_remove_reshape():
328 """
329 Tests that the expected reshape are removed in graph_optimisation
330 """

    def setup_network():
        quant = testutil.default_quant_params()
        # create reshape1 op
        ifm_shape = [64, 16]
        reshape1_ofm_shape = [1, 4, 16, 16]
        reshape1_ifm = create_const_tensor("reshape1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
        reshape1_ifm.quantization = quant
        reshape1_ofm = create_const_tensor(
            "reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape)
        )
        reshape1_ofm.quantization = quant
        shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
        reshape1_op = testutil.create_op(Op.Reshape, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
        reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
        reshape1_op.run_on_npu = True

        # create conv op
        conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
        conv_ofm.quantization = quant.clone()
        weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
        weight_tens.values = np.zeros(weight_tens.shape)
        weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
        weight_tens.quantization = quant.clone()
        bias_tens = Tensor([16], DataType.int32, "biases")

        attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
        attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)

        conv2d_op = testutil.create_op(
            Op.Conv2D, [reshape1_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
        )
        conv2d_op.run_on_npu = True

        # create reshape2 op
        ofm_shape = [8, 8, 16]
        reshape2_ofm = create_const_tensor("reshape2_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
        reshape2_ofm.quantization = quant
        shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, ofm_shape)
        reshape2_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
        reshape2_op.attrs["new_shape"] = ofm_shape
        reshape2_op.run_on_npu = True
        nng = Graph()
        sg = testutil.create_subgraph([reshape1_op, conv2d_op, reshape2_op])
        nng.subgraphs.append(sg)

        return nng, reshape1_op, conv2d_op, reshape2_op

    # Test1: no Reshape op is expected to remain in the NPU subgraph,
    # but the first one will be put on the CPU
    # Network is Reshape-Conv-Reshape
    # Result is Conv
    nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
    arch = testutil.create_arch()
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    assert verify_graph_health(nng)

    # Test2: reshape1 has different quantisation, so this Reshape op is expected to remain
    # Network is Reshape-Conv-Reshape
    # Expected result is Reshape-Conv
    nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
    quant_zp32 = testutil.default_quant_params()
    quant_zp32.zero_point = 32
    reshape1_op.ofm.quantization = quant_zp32
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    assert verify_graph_health(nng)