# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for tflite_graph_optimiser
import numpy as np
import pytest

from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import optimise_graph
from ethosu.vela.nn_graph import Graph
from ethosu.vela.nn_graph import NetworkType
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
from ethosu.vela.rewrite_graph import verify_graph_health
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import Shape4D
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
from ethosu.vela.tflite_graph_optimiser import calc_explicit_padding
from ethosu.vela.tflite_graph_optimiser import convert_batched_fc_shape
from ethosu.vela.tflite_graph_optimiser import replace_pad_by_hw_pad
from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input


def test_convert_batched_fc():
40 """Tests shape conversion of batched fully connected"""
    ifm_shape = [4, 8]
    ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
    w_shape = [8, 4]
    weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
    ofm = Tensor(ifm.shape, np.uint8, "test_out")
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)

    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

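    # rewrite_fully_connected_input derives the op's 4D ifm/ofm shapes, and
    # convert_batched_fc_shape then remaps the batch; the arch/nng arguments
    # are unused by this test, which is why None is passed for both.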
    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)
    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert op.ofm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape

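    # Second case: with a batch size of 1 there is nothing to convert, so all
    # op shapes are expected to be left untouched (compared against the clone).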
    ifm.shape = [1, 8]
    weights.shape = [8, 1]
    ofm.shape = [1, 8]
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)

    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == prev_op.ifm_shapes[0]
    assert op.ofm_shapes[0] == prev_op.ofm_shapes[0]
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape


explicit_padding_test_data = [
    # Kernel size 2
    [(17, 1, 2, 1, 1), (1, 1)],
    [(18, 1, 2, 0, 1), (0, 1)],
    [(18, 1, 2, 1, 0), (1, 0)],
    # Kernel size 3
    [(18, 2, 3, 1, 1), (1, 0)],
    [(25, 2, 3, 1, 1), (1, 1)],
    # Kernel size 4
    [(18, 1, 4, 1, 2), (1, 2)],
    [(18, 1, 4, 2, 1), (2, 1)],
    [(19, 1, 4, 2, 2), (2, 2)],
    # Kernel size 5
    [(19, 1, 5, 1, 2), (1, 2)],
    [(19, 1, 5, 0, 2), (0, 2)],
    [(19, 1, 5, 1, 0), (1, 0)],
    # Kernel size 21
    [(41, 2, 21, 8, 10), (8, 10)],
    [(41, 3, 21, 10, 10), (10, 9)],
    [(42, 3, 21, 10, 10), (10, 8)],
    [(42, 3, 21, 9, 10), (9, 9)],
    [(41, 3, 21, 10, 6), (10, 6)],
]
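# Each entry is [(input_size, stride, filter_size, pad_before, pad_after), (expected_before, expected_after)].
# A sketch of the expected behaviour, assuming calc_explicit_padding trims the requested
# padding to what the sliding window can actually consume: for (18, 2, 3, 1, 1) the padded
# size is 20 and the window starts are 0, 2, ..., 16, so the last window covers indices
# 16..18 and never reads the requested after-pad, which is therefore clamped, giving (1, 0).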


@pytest.mark.parametrize("test_input, expected_result", explicit_padding_test_data)
def test_calc_explicit_padding(test_input, expected_result):
    input_size, stride, filter_size, explicit_pad_before, explicit_pad_after = test_input
    before, after = calc_explicit_padding(input_size, stride, filter_size, explicit_pad_before, explicit_pad_after)
    assert (before, after) == expected_result


def create_pad_and_conv2d(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
    pad_setting=Padding.VALID,
    kernel_size=3,
):
130 """Creates Pad operator followed by a conv2d operator"""
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    op.run_on_npu = True
    conv_out_tens = Tensor(in_shape, in_dtype, "output")
    conv_out_tens.quantization = qp.clone()
    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
    weight_tens.values = np.zeros(weight_tens.shape, in_dtype.as_numpy_type())
    weight_tens.quantization = qp.clone()
    bias_tens = Tensor(out_shape, pad_dtype, "biases")
    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
    conv2d_op.add_input_tensor(out)
    conv2d_op.run_on_npu = True
    return op, conv2d_op


def test_pad_followed_by_conv_is_removed():
    """
    Tests that the PAD operator is bypassed when followed by a convolution operator,
    and that the padding of the convolution operation is correctly updated
    """
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 76, 75, 64], out_shape=[1, 76, 75, 64], padding=[[0, 0], [2, 1], [1, 1], [0, 0]], kernel_size=4
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.Conv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops


leading_pad_test_data = [
    (2, 2, 11, True),
    (1, 2, 11, False),
    (2, 1, 11, False),
    (5, 2, 11, True),
]
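# Each entry is (top, left, kernel_size, expect_pad_removed); create_pad_and_conv2d
# uses stride 2, so a leading pad of 1 cannot be folded into the convolution.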


@pytest.mark.parametrize("top, left, kernel_size, expect_pad_removed", leading_pad_test_data)
def test_leading_pad_size(top, left, kernel_size, expect_pad_removed):
    # Tests the PAD operator with a big kernel size; the top and left pad must be a multiple of the stride
    out_shape = [1, 11 + top, 11 + left, 1]
    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()
    replace_pad_by_hw_pad(conv2d_op, nng, arch)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert "explicit_padding" in op.attrs
        assert op.ifm.shape == op.ofm.shape
        assert pad_op not in op.ifm.ops
    else:
        assert pad_op in op.ifm.ops
        assert op.attrs["padding"] == Padding.VALID
        assert "explicit_padding" not in op.attrs


def test_optimise_pad_followed_by_avg_pool():
    """
    Tests that the PAD operator is bypassed when followed by an average pool operator,
    and that the average pool is converted to a depthwise convolution
    """
    # Create Pad operation followed by AvgPool
    quant = testutil.default_quant_params()
    in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
    in_tens.quantization = quant
    # Test with a 3x2 padding tensor
    pad_input = create_const_tensor("pad_input", [3, 2], DataType.int32, [[2, 2], [1, 1], [0, 0]])
    temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
    temp_tens.quantization = quant.clone()
    out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
    out_tens.quantization = quant.clone()

    pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, 5, 3, 1],
        "stride_w": 2,
        "stride_h": 2,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    pad_op.run_on_npu = True
    avgpool_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
    avgpool_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, avgpool_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(avgpool_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.DepthwiseConv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 2, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops
    # Check that bias and weight tensors have been added
    assert op.bias.shape == [64]
    assert op.weights.shape == [5, 3, 1, 64]


pad_avg_pool_test_data = [
    ((3, 3), (1, 1, 1, 1), True),
    ((3, 3), (2, 1, 1, 1), False),
    ((3, 3), (1, 2, 1, 1), False),
    ((3, 3), (1, 1, 2, 1), False),
    ((3, 3), (1, 1, 1, 2), False),
    ((2, 4), (1, 2, 1, 2), True),
    ((5, 3), (2, 1, 2, 1), True),
    ((5, 3), (0, 1, 2, 1), True),
    ((5, 3), (2, 0, 2, 1), True),
    ((5, 3), (2, 1, 0, 1), True),
    ((4, 4), (2, 2, 2, 2), True),
    ((4, 4), (1, 2, 2, 2), False),
    ((4, 4), (2, 1, 2, 2), False),
    ((4, 4), (2, 2, 1, 2), False),
    ((4, 4), (2, 2, 2, 1), False),
]
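# Each entry is ((k_h, k_w), (top, left, bottom, right), expect_pad_removed). The data
# suggests a pad is only folded into the pooling when each edge pad stays within what
# the kernel can consume (at most kernel_dim // 2, with even kernels being stricter).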


@pytest.mark.parametrize("k_size, padding, expect_pad_removed", pad_avg_pool_test_data)
def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
    # Tests PAD followed by AvgPool
    k_h, k_w = k_size
    top, left, bottom, right = padding
    pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
    dtype = DataType.int8
    qp = testutil.default_quant_params()
    in_shape = [1, 15, 17, 8]
    out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
    in0 = Tensor(in_shape, dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(
        name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
    )
    out = Tensor(out_shape, dtype, "out")
    out.quantization = qp.clone()
    pad_op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    pool_out_tens = Tensor(in_shape, dtype, "output")
    pool_out_tens.quantization = qp.clone()
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, k_h, k_w, 1],
        "stride_w": 1,
        "stride_h": 1,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
    pad_op.run_on_npu = True
    pool_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, pool_op])
    arch = testutil.create_arch()
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    sg = nng.subgraphs[0]
    all_ops = sg.get_all_ops()
    print("all_ops: ", all_ops)
    # Pad should not be in the graph anymore; it should either have been removed or rewritten
    assert not any(op.type == Op.Pad for op in all_ops)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        # Expect rewrite to depthwise, PAD is removed
        assert op.type == Op.DepthwiseConv2DBias
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert any(pad > 0 for pad in op.attrs["explicit_padding"])
        assert op.ifm.shape == op.ofm.shape
        # Check that bias and weight tensors have been added
        assert len(op.bias.shape) > 0
        assert op.weights.shape is not None
    else:
        # Pad should have been rewritten to a number of average pool operations
        assert all(op.type in (Op.AvgPool, Op.Const) for op in all_ops)
        assert pool_op.type == Op.AvgPool
        assert pool_op.attrs["padding"] == Padding.VALID


def test_remove_reshape():
    """
    Tests that the expected Reshape operators are removed during graph optimisation
    """

    def setup_network():
        quant = testutil.default_quant_params()
        # create reshape1 op
        ifm_shape = [64, 16]
        reshape1_ofm_shape = [1, 4, 16, 16]
        reshape1_ifm = create_const_tensor("reshape1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
        reshape1_ifm.quantization = quant
        reshape1_ofm = create_const_tensor(
            "reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape)
        )
        reshape1_ofm.quantization = quant
        shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
        reshape1_op = testutil.create_op(Op.Reshape, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
        reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
        reshape1_op.run_on_npu = True

        # create conv op
        conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
        conv_ofm.quantization = quant.clone()
        weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
        weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
        weight_tens.quantization = quant.clone()
        bias_tens = Tensor([16], DataType.int32, "biases")

        attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
        attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)

        conv2d_op = testutil.create_op(
            Op.Conv2D, [reshape1_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
        )
        conv2d_op.run_on_npu = True

        # create reshape2 op
        ofm_shape = [8, 8, 16]
        reshape2_ofm = create_const_tensor("reshape2_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
        reshape2_ofm.quantization = quant
        shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, ofm_shape)
        reshape2_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
        reshape2_op.attrs["new_shape"] = ofm_shape
        reshape2_op.run_on_npu = True
        nng = Graph()
        sg = testutil.create_subgraph([reshape1_op, conv2d_op, reshape2_op])
        nng.subgraphs.append(sg)

        return nng, reshape1_op, conv2d_op, reshape2_op

    # Test 1: no Reshape op is expected to remain in the NPU subgraph,
    # but the first one will be placed on the CPU.
    # Network is Reshape-Conv-Reshape
    # Result is Conv
381 nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
Patrik Gustavsson3a269202021-01-21 08:28:55 +0100382 arch = testutil.create_arch()
383 assert verify_graph_health(nng)
Patrik Gustavsson8f1f9aa2021-06-28 07:41:58 +0200384 nng = optimise_graph(nng, arch, NetworkType.TFLite)
Patrik Gustavsson3a269202021-01-21 08:28:55 +0100385 assert verify_graph_health(nng)
Patrik Gustavsson3a269202021-01-21 08:28:55 +0100386
    # Test 2: reshape1 has different quantisation, so this Reshape op is expected to remain.
    # Network is Reshape-Conv-Reshape
    # Expected result is Reshape-Conv
    nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
    quant_zp32 = testutil.default_quant_params()
    quant_zp32.zero_point = 32
    reshape1_op.ofm.quantization = quant_zp32
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    assert verify_graph_health(nng)