# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for tflite_graph_optimiser
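# Note: these tests exercise individual optimiser rewrites in isolation and can be
# run with pytest, e.g. "pytest ethosu/vela/test/test_tflite_graph_optimiser.py"
# (file path assumed from the package layout of the imports below).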
import numpy as np
import pytest

from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import optimise_graph
from ethosu.vela.nn_graph import Graph
from ethosu.vela.nn_graph import NetworkType
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
from ethosu.vela.rewrite_graph import verify_graph_health
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import Shape4D
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
from ethosu.vela.tflite_graph_optimiser import calc_explicit_padding
from ethosu.vela.tflite_graph_optimiser import convert_batched_fc_shape
from ethosu.vela.tflite_graph_optimiser import replace_pad_by_hw_pad
from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input


def test_convert_batched_fc():
    """Tests shape conversion of batched fully connected"""
    ifm_shape = [4, 8]
    ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
    w_shape = [8, 4]
    weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
    ofm = Tensor(ifm.shape, np.uint8, "test_out")
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)

    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

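    # The rewrite is expected to fold the batch of 4 into a 2x2 spatial area, turning
    # the 2D [4, 8] shapes into the 4D [1, 2, 2, 8] shapes asserted below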
    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)
    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert op.ofm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape

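    # With a batch of 1 no spatial folding is needed, so the shapes are expected to
    # pass through unchanged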
    ifm.shape = [1, 8]
    weights.shape = [8, 1]
    ofm.shape = [1, 8]
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)

    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == prev_op.ifm_shapes[0]
    assert op.ofm_shapes[0] == prev_op.ofm_shapes[0]
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape


explicit_padding_test_data = [
    # Kernel size 2
    [(17, 1, 2, 1, 1), (1, 1)],
    [(18, 1, 2, 0, 1), (0, 1)],
    [(18, 1, 2, 1, 0), (1, 0)],
    # Kernel size 3
    [(18, 2, 3, 1, 1), (1, 0)],
    [(25, 2, 3, 1, 1), (1, 1)],
    # Kernel size 4
    [(18, 1, 4, 1, 2), (1, 2)],
    [(18, 1, 4, 2, 1), (2, 1)],
    [(19, 1, 4, 2, 2), (2, 2)],
    # Kernel size 5
    [(19, 1, 5, 1, 2), (1, 2)],
    [(19, 1, 5, 0, 2), (0, 2)],
    [(19, 1, 5, 1, 0), (1, 0)],
    # Kernel size 21
    [(41, 2, 21, 8, 10), (8, 10)],
    [(41, 3, 21, 10, 10), (10, 9)],
    [(42, 3, 21, 10, 10), (10, 8)],
    [(42, 3, 21, 9, 10), (9, 9)],
    [(41, 3, 21, 10, 6), (10, 6)],
]

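# Each test_input is (input_size, stride, filter_size, pad_before, pad_after) and the
# expected result is (before, after); judging by the cases above, calc_explicit_padding
# returns only the part of the requested padding that the kernel actually makes use of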
@pytest.mark.parametrize("test_input, expected_result", explicit_padding_test_data)
def test_calc_explicit_padding(test_input, expected_result):
    input_size, stride, filter_size, explicit_pad_before, explicit_pad_after = test_input
    before, after = calc_explicit_padding(input_size, stride, filter_size, explicit_pad_before, explicit_pad_after)
    assert (before, after) == expected_result


def create_pad_and_conv2d(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
    pad_setting=Padding.VALID,
    kernel_size=3,
):
    """Creates a Pad operator followed by a Conv2D operator"""
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    op.run_on_npu = True
    conv_out_tens = Tensor(in_shape, in_dtype, "output")
    conv_out_tens.quantization = qp.clone()
    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
    weight_tens.values = np.zeros(weight_tens.shape, in_dtype.as_numpy_type())
    weight_tens.quantization = qp.clone()
    bias_tens = Tensor(out_shape, pad_dtype, "biases")
    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
    conv2d_op.add_input_tensor(out)
    conv2d_op.run_on_npu = True
    return op, conv2d_op


def test_pad_followed_by_conv_is_removed():
    """
    Tests that the PAD operator is bypassed when followed by a convolution operator,
    and that the padding of the convolution operation is correctly updated
    """
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 76, 75, 64], out_shape=[1, 76, 75, 64], padding=[[0, 0], [2, 1], [1, 1], [0, 0]], kernel_size=4
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.Conv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops

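# Each entry is (top, left, kernel_size, expect_pad_removed); create_pad_and_conv2d
# uses stride 2, so these cases probe when the leading (top/left) padding of a
# big-kernel convolution can be folded into the hardware padding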
leading_pad_test_data = [
    (2, 2, 11, True),
    (1, 2, 11, False),
    (2, 1, 11, False),
    (5, 2, 11, True),
]


@pytest.mark.parametrize("top, left, kernel_size, expect_pad_removed", leading_pad_test_data)
def test_leading_pad_size(top, left, kernel_size, expect_pad_removed):
    # Tests PAD operator with big kernel size; top and left pad must be multiple of stride
    out_shape = [1, 11 + left, 11 + top, 1]
    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()
    replace_pad_by_hw_pad(conv2d_op, nng, arch)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert "explicit_padding" in op.attrs
        assert op.ifm.shape == op.ofm.shape
        assert pad_op not in op.ifm.ops
    else:
        assert pad_op in op.ifm.ops
        assert op.attrs["padding"] == Padding.VALID
        assert "explicit_padding" not in op.attrs


def test_optimise_pad_followed_by_avg_pool():
    """
    Tests that the PAD operator is bypassed when followed by an average pool operator,
    and that the average pool is converted to a depthwise convolution
    """
    # Create Pad operation followed by AvgPool
    quant = testutil.default_quant_params()
    in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
    in_tens.quantization = quant
    # Test with a 3x2 pad input tensor; the padding values omit the batch dimension
    pad_input = create_const_tensor("pad_input", [3, 2], DataType.int32, [[2, 2], [1, 1], [0, 0]])
    temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
    temp_tens.quantization = quant.clone()
    out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
    out_tens.quantization = quant.clone()

    pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, 5, 3, 1],
        "stride_w": 2,
        "stride_h": 2,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    pad_op.run_on_npu = True
    conv2d_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
    conv2d_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.DepthwiseConv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 2, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops
    # Check that bias and weight tensors have been added
    assert op.bias.shape == [64]
    assert op.weights.shape == [5, 3, 1, 64]


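# Each entry is ((kernel_w, kernel_h), (top, left, bottom, right), expect_pad_removed);
# when the PAD cannot be fused into hardware padding, it is instead rewritten to a
# number of average pool operations (see the else branch of the test below)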
pad_avg_pool_test_data = [
    ((3, 3), (1, 1, 1, 1), True),
    ((3, 3), (2, 1, 1, 1), False),
    ((3, 3), (1, 2, 1, 1), False),
    ((3, 3), (1, 1, 2, 1), False),
    ((3, 3), (1, 1, 1, 2), False),
    ((2, 4), (1, 2, 1, 2), True),
    ((5, 3), (2, 1, 2, 1), True),
    ((5, 3), (0, 1, 2, 1), True),
    ((5, 3), (2, 0, 2, 1), True),
    ((5, 3), (2, 1, 0, 1), True),
    ((5, 3), (2, 1, 2, 0), True),
    ((4, 4), (2, 2, 2, 2), True),
    ((4, 4), (1, 2, 2, 2), False),
    ((4, 4), (2, 1, 2, 2), False),
    ((4, 4), (2, 2, 1, 2), False),
    ((4, 4), (2, 2, 2, 1), False),
]


@pytest.mark.parametrize("k_size, padding, expect_pad_removed", pad_avg_pool_test_data)
def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
    # Tests PAD followed by AvgPool
    k_w, k_h = k_size
    top, left, bottom, right = padding
    pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
    dtype = DataType.int8
    qp = testutil.default_quant_params()
    in_shape = [1, 15, 17, 8]
    out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
    in0 = Tensor(in_shape, dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(
        name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
    )
    out = Tensor(out_shape, dtype, "out")
    out.quantization = qp.clone()
    pad_op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    pool_out_tens = Tensor(in_shape, dtype, "output")
    pool_out_tens.quantization = qp.clone()
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, k_w, k_h, 1],
        "stride_w": 1,
        "stride_h": 1,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
    pad_op.run_on_npu = True
    pool_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, pool_op])
    arch = testutil.create_arch()
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    sg = nng.subgraphs[0]
    all_ops = sg.get_all_ops()
    print("all_ops: ", all_ops)
    # Pad should not be in the graph anymore; it should either have been removed or rewritten
    assert not any(op.type == Op.Pad for op in all_ops)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        # Expect rewrite to depthwise, PAD is removed
        assert op.type == Op.DepthwiseConv2DBias
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert any(pad > 0 for pad in op.attrs["explicit_padding"])
        assert op.ifm.shape == op.ofm.shape
        # Check that bias and weight tensors have been added
        assert len(op.bias.shape) > 0
        assert op.weights.shape is not None
    else:
        # Pad should have been rewritten to a number of average pool operations
        assert all(op.type in (Op.AvgPool, Op.Const) for op in all_ops)
        assert pool_op.type == Op.AvgPool
        assert pool_op.attrs["padding"] == Padding.VALID


# Sets up a network to test removal of ops with op_type Op.Reshape or Op.Squeeze
def setup_network(op_type):
    assert op_type == Op.Reshape or op_type == Op.Squeeze
    if op_type == Op.Reshape:
        op_str = "reshape"
    elif op_type == Op.Squeeze:
        op_str = "squeeze"

    quant = testutil.default_quant_params()
    # create reshape1 op
    ifm_shape = [64, 16]
    reshape1_ofm_shape = [1, 4, 16, 16]
    reshape1_ifm = create_const_tensor(f"{op_str}1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    reshape1_ifm.quantization = quant
    reshape1_ofm = create_const_tensor(
        f"{op_str}1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape)
    )
    reshape1_ofm.quantization = quant
    shape_tens = create_const_tensor(f"{op_str}1_shape", [1], DataType.int32, reshape1_ofm_shape)
    reshape1_op = testutil.create_op(op_type, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
    reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
    reshape1_op.run_on_npu = True

    # create conv op
    conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
    conv_ofm.quantization = quant.clone()
    weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
    weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
    weight_tens.quantization = quant.clone()
    bias_tens = Tensor([16], DataType.int32, "biases")

    attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)

    conv2d_op = testutil.create_op(
        Op.Conv2D, [reshape1_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
    )
    conv2d_op.run_on_npu = True

    # create reshape2 op
    ofm_shape = [8, 8, 16]
    reshape2_ofm = create_const_tensor(f"{op_str}2_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    reshape2_ofm.quantization = quant
    shape_tens = create_const_tensor(f"{op_str}2_shape", [1], DataType.int32, ofm_shape)
    reshape2_op = testutil.create_op(op_type, [conv_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
    reshape2_op.attrs["new_shape"] = ofm_shape
    reshape2_op.run_on_npu = True
    nng = Graph()
    sg = testutil.create_subgraph([reshape1_op, conv2d_op, reshape2_op])
    nng.subgraphs.append(sg)

    return nng, reshape1_op, conv2d_op, reshape2_op


def test_remove_reshape():
    """
    Tests that the expected Reshape ops are removed in graph optimisation
    """

    # Test1: no Reshape op is expected to remain in the NPU subgraph,
    # but the first one will be put on the CPU
    # Network is Reshape-Conv-Reshape
    # Result is Conv
    nng, reshape1_op, conv2d_op, reshape2_op = setup_network(Op.Reshape)
    arch = testutil.create_arch()
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    assert verify_graph_health(nng)

    # Test2: reshape1 with different quantisation; this Reshape op is expected to remain
    # Network is Reshape-Conv-Reshape
    # Expected result is Reshape-Conv
    nng, reshape1_op, conv2d_op, reshape2_op = setup_network(Op.Reshape)
    quant_zp32 = testutil.default_quant_params()
    quant_zp32.zero_point = 32
    reshape1_op.ofm.quantization = quant_zp32
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    assert verify_graph_health(nng)


def test_remove_squeeze():
    """
    Tests that the expected Squeeze ops are removed in graph optimisation
    """

    # Test1: no Squeeze op is expected to remain in the NPU subgraph,
    # but the first one will be put on the CPU
    # Network is Squeeze-Conv-Squeeze
    # Result is Conv
    nng, squeeze1_op, conv2d_op, squeeze2_op = setup_network(Op.Squeeze)
    arch = testutil.create_arch()
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    assert verify_graph_health(nng)

    # Test2: squeeze1 with different quantisation; this Squeeze op is expected to remain
    # Network is Squeeze-Conv-Squeeze
    # Expected result is Squeeze-Conv
    nng, squeeze1_op, conv2d_op, squeeze2_op = setup_network(Op.Squeeze)
    quant_zp32 = testutil.default_quant_params()
    quant_zp32.zero_point = 32
    squeeze1_op.ofm.quantization = quant_zp32
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    assert verify_graph_health(nng)