# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for tflite_graph_optimiser
import numpy as np
import pytest

from ethosu.vela.data_type import DataType
from ethosu.vela.graph_optimiser import optimise_graph
from ethosu.vela.nn_graph import NetworkType
from ethosu.vela.operation import Op
from ethosu.vela.operation import Operation
from ethosu.vela.operation import Padding
from ethosu.vela.rewrite_graph import verify_graph_health
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import Shape4D
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
from ethosu.vela.tflite_graph_optimiser import calc_explicit_padding
from ethosu.vela.tflite_graph_optimiser import convert_batched_fc_shape
from ethosu.vela.tflite_graph_optimiser import optimise_quantize
from ethosu.vela.tflite_graph_optimiser import replace_pad_by_hw_pad
from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input


def test_convert_batched_fc():
    """Tests shape conversion of batched fully connected"""
    ifm_shape = [4, 8]
    ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
    w_shape = [8, 4]
    weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
    ofm = Tensor(ifm.shape, np.uint8, "test_out")
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)

    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)
    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert op.ofm_shapes[0] == Shape4D([1, 2, 2, 8])
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape

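    # Repeat with batch size 1: the ifm/ofm shapes are expected to be left unchanged by the rewrite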
    ifm.shape = [1, 8]
    weights.shape = [8, 1]
    ofm.shape = [1, 8]
    op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
    ifm.consumer_list.append(op)

    prev_op = op.clone()
    prev_op.ifm_shapes = op.ifm_shapes.copy()
    prev_op.ofm_shapes = op.ofm_shapes.copy()

    rewrite_fully_connected_input(op, None, None)
    conv_op = convert_batched_fc_shape(op, None, None)

    assert conv_op.ifm == prev_op.ifm
    assert conv_op.ofm == prev_op.ofm
    assert op.ifm_shapes[0] == prev_op.ifm_shapes[0]
    assert op.ofm_shapes[0] == prev_op.ofm_shapes[0]
    assert conv_op.type == Op.FullyConnected
    assert len(conv_op.ifm.shape) == 2
    assert len(conv_op.ofm.shape) == 2
    assert conv_op.ifm.shape == conv_op.ofm.shape


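# Each entry is ((input_size, stride, filter_size, explicit_pad_before, explicit_pad_after),
#                (expected_pad_before, expected_pad_after))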
explicit_padding_test_data = [
    # Kernel size 2
    [(17, 1, 2, 1, 1), (1, 1)],
    [(18, 1, 2, 0, 1), (0, 1)],
    [(18, 1, 2, 1, 0), (1, 0)],
    # Kernel size 3
    [(18, 2, 3, 1, 1), (1, 0)],
    [(25, 2, 3, 1, 1), (1, 1)],
    # Kernel size 4
    [(18, 1, 4, 1, 2), (1, 2)],
    [(18, 1, 4, 2, 1), (2, 1)],
    [(19, 1, 4, 2, 2), (2, 2)],
    # Kernel size 5
    [(19, 1, 5, 1, 2), (1, 2)],
    [(19, 1, 5, 0, 2), (0, 2)],
    [(19, 1, 5, 1, 0), (1, 0)],
    # Kernel size 21
    [(41, 2, 21, 8, 10), (8, 10)],
    [(41, 3, 21, 10, 10), (10, 9)],
    [(42, 3, 21, 10, 10), (10, 8)],
    [(42, 3, 21, 9, 10), (9, 9)],
    [(41, 3, 21, 10, 6), (10, 6)],
]


@pytest.mark.parametrize("test_input, expected_result", explicit_padding_test_data)
def test_calc_explicit_padding(test_input, expected_result):
    input_size, stride, filter_size, explicit_pad_before, explicit_pad_after = test_input
    before, after = calc_explicit_padding(input_size, stride, filter_size, explicit_pad_before, explicit_pad_after)
    assert (before, after) == expected_result


def create_pad_and_conv2d(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
    pad_setting=Padding.VALID,
    kernel_size=3,
):
    """Creates a Pad operator followed by a conv2d operator"""
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    op.run_on_npu = True
    conv_out_tens = Tensor(in_shape, in_dtype, "output")
    conv_out_tens.quantization = qp.clone()
    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
    weight_tens.values = np.zeros(weight_tens.shape, in_dtype.as_numpy_type())
    weight_tens.quantization = qp.clone()
    bias_tens = Tensor(out_shape, pad_dtype, "biases")
    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
    conv2d_op.add_input_tensor(out)
    conv2d_op.run_on_npu = True
    return op, conv2d_op


def test_pad_followed_by_conv_is_removed():
    """
    Tests that the PAD operator is bypassed when followed by a convolution operator,
    and that the padding of the convolution operation is correctly updated
    """
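    # Build Pad -> Conv2DBias and run replace_pad_by_hw_pad; the Pad op should be
    # folded into the convolution as explicit (hardware) padding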
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 76, 75, 64], out_shape=[1, 76, 75, 64], padding=[[0, 0], [2, 1], [1, 1], [0, 0]], kernel_size=4
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.Conv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops


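# Each entry is (top, left, kernel_size, expect_pad_removed)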
leading_pad_test_data = [
    (2, 2, 11, True),
    (1, 2, 11, False),
    (2, 1, 11, False),
    (5, 2, 11, True),
]


@pytest.mark.parametrize("top, left, kernel_size, expect_pad_removed", leading_pad_test_data)
def test_leading_pad_size(top, left, kernel_size, expect_pad_removed):
    # Tests the PAD operator with a big kernel size; top and left pad must be a multiple of the stride
    out_shape = [1, 11 + left, 11 + top, 1]
    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
    pad_op, conv2d_op = create_pad_and_conv2d(
        in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size
    )
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()
    replace_pad_by_hw_pad(conv2d_op, nng, arch)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert "explicit_padding" in op.attrs
        assert op.ifm.shape == op.ofm.shape
        assert pad_op not in op.ifm.ops
    else:
        assert pad_op in op.ifm.ops
        assert op.attrs["padding"] == Padding.VALID
        assert "explicit_padding" not in op.attrs


def test_optimise_pad_followed_by_avg_pool():
    """
    Tests that the PAD operator is bypassed when followed by an average pool operator,
    and that the average pool is converted to a depthwise convolution
    """
    # Create Pad operation followed by AvgPool
    quant = testutil.default_quant_params()
    in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
    in_tens.quantization = quant
    # Test with a 3x2 pad input tensor
    pad_input = create_const_tensor("pad_input", [3, 2], DataType.int32, [[2, 2], [1, 1], [0, 0]])
    temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
    temp_tens.quantization = quant.clone()
    out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
    out_tens.quantization = quant.clone()

    pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, 5, 3, 1],
        "stride_w": 2,
        "stride_h": 2,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
    pad_op.run_on_npu = True
    conv2d_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
    conv2d_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, conv2d_op])
    arch = testutil.create_arch()

    replace_pad_by_hw_pad(conv2d_op, nng, arch)

    op = nng.subgraphs[0].output_tensors[0].ops[0]
    assert op.type == Op.DepthwiseConv2DBias
    assert op.attrs["padding"] == Padding.EXPLICIT
    assert op.attrs["explicit_padding"] == (2, 1, 2, 1)
    assert op.ifm.shape == [1, 76, 75, 64]
    assert pad_op not in op.ifm.ops
    # Check that bias and weight tensors have been added
    assert op.bias.shape == [64]
    assert op.weights.shape == [5, 3, 1, 64]


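# Each entry is ((k_w, k_h), (top, left, bottom, right), expect_pad_removed)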
pad_avg_pool_test_data = [
    ((3, 3), (1, 1, 1, 1), True),
    ((3, 3), (2, 1, 1, 1), False),
    ((3, 3), (1, 2, 1, 1), False),
    ((3, 3), (1, 1, 2, 1), False),
    ((3, 3), (1, 1, 1, 2), False),
    ((2, 4), (1, 2, 1, 2), True),
    ((5, 3), (2, 1, 2, 1), True),
    ((5, 3), (0, 1, 2, 1), True),
    ((5, 3), (2, 0, 2, 1), True),
    ((5, 3), (2, 1, 0, 1), True),
    ((5, 3), (2, 1, 0, 1), True),
    ((4, 4), (2, 2, 2, 2), True),
    ((4, 4), (1, 2, 2, 2), False),
    ((4, 4), (2, 1, 2, 2), False),
    ((4, 4), (2, 2, 1, 2), False),
    ((4, 4), (2, 2, 2, 1), False),
]


@pytest.mark.parametrize("k_size, padding, expect_pad_removed", pad_avg_pool_test_data)
def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
    # Tests PAD followed by AvgPool
    k_w, k_h = k_size
    top, left, bottom, right = padding
    pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
    dtype = DataType.int8
    qp = testutil.default_quant_params()
    in_shape = [1, 15, 17, 8]
    out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
    in0 = Tensor(in_shape, dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(
        name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
    )
    out = Tensor(out_shape, dtype, "out")
    out.quantization = qp.clone()
    pad_op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    pool_out_tens = Tensor(in_shape, dtype, "output")
    pool_out_tens.quantization = qp.clone()
    attrs = {
        "padding": Padding.VALID,
        "ksize": [1, k_w, k_h, 1],
        "stride_w": 1,
        "stride_h": 1,
        "dilation_w_factor": 1,
        "dilation_h_factor": 1,
    }
    pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
    pad_op.run_on_npu = True
    pool_op.run_on_npu = True
    nng = testutil.create_graph([pad_op, pool_op])
    arch = testutil.create_arch()
    nng = optimise_graph(nng, arch, NetworkType.TFLite)
    sg = nng.subgraphs[0]
    all_ops = sg.get_all_ops()
    print("all_ops: ", all_ops)
    # Pad should not be in the graph anymore; it should either have been removed or rewritten
    assert not any(op.type == Op.Pad for op in all_ops)
    op = nng.subgraphs[0].output_tensors[0].ops[0]
    if expect_pad_removed:
        # Expect rewrite to depthwise, PAD is removed
        assert op.type == Op.DepthwiseConv2DBias
        assert op.attrs["padding"] == Padding.EXPLICIT
        assert any(pad > 0 for pad in op.attrs["explicit_padding"])
        assert op.ifm.shape == op.ofm.shape
        # Check that bias and weight tensors have been added
        assert len(op.bias.shape) > 0
        assert op.weights.shape is not None
    else:
        # Pad should have been rewritten to a number of average pool operations
        assert all(op.type in (Op.AvgPool, Op.Const) for op in all_ops)
        assert pool_op.type == Op.AvgPool
        assert pool_op.attrs["padding"] == Padding.VALID


def test_remove_reshape():
    """
    Tests that the expected Reshape ops are removed in graph optimisation
    """

    # Create tensors and operators Test1
    quant = testutil.default_quant_params()

    # create reshape1 op
    ifm_shape = [64, 16]
    reshape1_ofm_shape = [1, 4, 16, 16]
    reshape1_ifm = create_const_tensor("reshape1_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    reshape1_ifm.quantization = quant
    reshape1_ofm = create_const_tensor("reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape))
    reshape1_ofm.quantization = quant
    shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
    reshape1_op = testutil.create_op(Op.Reshape, [reshape1_ifm, shape_tens], reshape1_ofm, set_ifm_ofm_shapes=False)
    reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
    reshape1_op.run_on_npu = True

    # create conv op
    conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
    conv_ofm.quantization = quant.clone()
    weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
    weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
    weight_tens.quantization = quant.clone()
    bias_tens = Tensor([16], DataType.int32, "biases")

    attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)

    conv2d_op = testutil.create_op(
        Op.Conv2D, [reshape1_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
    )
    conv2d_op.run_on_npu = True

    # create reshape2 op
    ofm_shape = [8, 8, 16]
    reshape2_ofm = create_const_tensor("reshape2_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    reshape2_ofm.quantization = quant
    shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, ofm_shape)
    reshape2_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
    reshape2_op.attrs["new_shape"] = ofm_shape
    reshape2_op.run_on_npu = True

    # Test1: no Reshape op is expected to remain in the NPU subgraph,
    # but the first one will be put on the CPU
    # Network is Reshape-Conv-Reshape
    # Result is Conv
    nng = testutil.create_graph([reshape1_op, conv2d_op, reshape2_op])
    arch = testutil.create_arch()
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
    assert verify_graph_health(nng)

    # Create tensors and operator Test2
    # create reshape op
    reshape_ifm = create_const_tensor("reshape_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    reshape_ifm.quantization = quant
    reshape_ofm = create_const_tensor("reshape1_out", reshape1_ofm_shape, DataType.uint8, np.zeros(reshape1_ofm_shape))
    reshape_ofm.quantization = quant
    shape_tens = create_const_tensor("reshape1_shape", [1], DataType.int32, reshape1_ofm_shape)
    reshape_op = testutil.create_op(Op.Reshape, [reshape_ifm, shape_tens], reshape_ofm, set_ifm_ofm_shapes=False)
    reshape_op.attrs["new_shape"] = reshape1_ofm_shape
    reshape_op.run_on_npu = True

    # Test2: Reshape ifm/ofm is sg input/output.
    # Reshape op is expected to be replaced by an AvgPool 'NOP'.
    #
    # Network is Reshape
    # Expected result is AvgPool
    nng = testutil.create_graph([reshape_op])
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
    assert verify_graph_health(nng)


def test_remove_squeeze():
    """
    Tests that the expected Squeeze ops are removed in graph optimisation
    """

    # Create tensors and operators Test1
    quant = testutil.default_quant_params()

    # create conv op
    ifm_shape = [1, 1, 1, 1024]
    conv_ifm = create_const_tensor("conv_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    conv_ifm.quantization = quant
    conv_ofm = Tensor([1, 1, 1, 1001], DataType.uint8, "output")
    conv_ofm.quantization = quant.clone()
    weight_tens = Tensor([1, 1, 1024, 1001], DataType.uint8, "weights")
    weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
    weight_tens.quantization = quant.clone()
    bias_tens = Tensor([1001], DataType.int32, "biases")

    attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)

    conv2d_op = testutil.create_op(
        Op.Conv2D, [conv_ifm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
    )
    conv2d_op.run_on_npu = True

    # create squeeze op
    ofm_shape = [1, 1001]
    squeeze_ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    squeeze_ofm.quantization = quant.clone()
    squeeze_op = testutil.create_op(Op.Squeeze, [conv_ofm], squeeze_ofm, set_ifm_ofm_shapes=False)
    squeeze_op.attrs["squeeze_dims"] = [1, 2]
    squeeze_op.run_on_npu = True

    # Test1: no Squeeze op is expected to remain in the NPU subgraph
    #
    # Network is Conv-Squeeze
    # Result is Conv
    nng = testutil.create_graph([conv2d_op, squeeze_op])
    arch = testutil.create_arch()
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
    assert verify_graph_health(nng)

    # Create tensors and operator Test2
    # create squeeze op
    ifm_shape = [1, 1, 1, 1001]
    squeeze_ifm = create_const_tensor("squeeze_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    squeeze_ifm.quantization = quant
    squeeze_ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    squeeze_ofm.quantization = quant.clone()
    squeeze_op = testutil.create_op(Op.Squeeze, [squeeze_ifm], squeeze_ofm, set_ifm_ofm_shapes=False)
    squeeze_op.attrs["squeeze_dims"] = [1, 2]
    squeeze_op.run_on_npu = True

    # Test2: Squeeze ifm/ofm is sg input/output.
    # Squeeze op is expected to be replaced by an AvgPool 'NOP'.
    #
    # Network is Squeeze
    # Expected result is AvgPool
    nng = testutil.create_graph([squeeze_op])
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
    assert verify_graph_health(nng)


def test_remove_expand_dims():
    """
    Tests that the expected ExpandDims ops are removed in graph optimisation
    """

    # Create tensors and operators Test1
    quant = testutil.default_quant_params()

    # create ExpandDims op
    ifm_shape = [4, 16, 16]
    ofm_shape = [1, 4, 16, 16]
    expand_dims_ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    expand_dims_ifm.quantization = quant
    expand_dims_ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    expand_dims_ofm.quantization = quant.clone()
    dim_tens = create_const_tensor("dim_tens", [], DataType.uint8, 1)
    expand_dims_op = testutil.create_op(
        Op.ExpandDims, [expand_dims_ifm, dim_tens], expand_dims_ofm, set_ifm_ofm_shapes=False
    )
    expand_dims_op.run_on_npu = True

    # create conv op
    conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
    conv_ofm.quantization = quant.clone()
    weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
    weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
    weight_tens.quantization = quant.clone()
    bias_tens = Tensor([16], DataType.int32, "biases")

    attrs = {"padding": Padding.SAME, "stride_w": 1, "stride_h": 1, "dilation_w_factor": 1, "dilation_h_factor": 1}
    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)

    conv2d_op = testutil.create_op(
        Op.Conv2D, [expand_dims_ofm, weight_tens, bias_tens], conv_ofm, attrs=attrs, set_ifm_ofm_shapes=False
    )
    conv2d_op.run_on_npu = True

    # Test1: no ExpandDims op is expected to remain in the NPU subgraph
    #
    # Network is ExpandDims-Conv
    # Result is Conv
    nng = testutil.create_graph([expand_dims_op, conv2d_op])
    arch = testutil.create_arch()
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
    assert verify_graph_health(nng)

    # create ExpandDims op
    expand_dims_ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    expand_dims_ifm.quantization = quant
    expand_dims_ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    expand_dims_ofm.quantization = quant.clone()
    dim_tens = create_const_tensor("dim_tens", [], DataType.uint8, 1)
    expand_dims_op = testutil.create_op(
        Op.ExpandDims, [expand_dims_ifm, dim_tens], expand_dims_ofm, set_ifm_ofm_shapes=False
    )
    expand_dims_op.run_on_npu = True

    # Test2: ExpandDims ifm/ofm is sg input/output.
    # ExpandDims op is expected to be replaced by an AvgPool 'NOP'.
    #
    # Network is ExpandDims
    # Expected result is AvgPool
    nng = testutil.create_graph([expand_dims_op])
    assert verify_graph_health(nng)
    nng = optimise_graph(nng, arch, NetworkType.TFLite, True)
    assert verify_graph_health(nng)


def test_quant_static_optimisations():
    """
    Tests that the quantised value is calculated correctly at Vela compile time
    """
    quant_ifm = create_const_tensor(
        "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
    )
    quant_ifm.quantization = testutil.default_quant_params()
    quant_ifm.quantization.scale_f32 = 0.15748031
    quant_ifm.quantization.quant_min = -128
    quant_ifm.quantization.quant_max = 127

    quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
    quant_ofm.quantization = testutil.default_quant_params()
    quant_ofm.quantization.scale_f32 = 0.036092404
    quant_ofm.quantization.zero_point = -128
    quant_ofm.quantization.quant_min = -128
    quant_ofm.quantization.quant_max = 127

    # Create quant op
    quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
    quant_op.run_on_npu = True

    op: Operation = optimise_quantize(quant_op, None, None)

    assert op.ofm.values == 127

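    # Repeat the same check with a fresh set of tensors and a new Quantize op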
    quant_ifm = create_const_tensor(
        "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
    )
    quant_ifm.quantization = testutil.default_quant_params()
    quant_ifm.quantization.scale_f32 = 0.15748031
    quant_ifm.quantization.quant_min = -128
    quant_ifm.quantization.quant_max = 127

    quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
    quant_ofm.quantization = testutil.default_quant_params()
    quant_ofm.quantization.scale_f32 = 0.036092404
    quant_ofm.quantization.zero_point = -128
    quant_ofm.quantization.quant_min = -128
    quant_ofm.quantization.quant_max = 127

    # Create quant op
    quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
    quant_op.run_on_npu = True

    op: Operation = optimise_quantize(quant_op, None, None)

    assert op.ofm.values == 127


def test_optimise_quantize_multiple_values():
    """
    Tests that the quantised values are calculated correctly at Vela compile time
    when passing multiple values to the quantize node
    """
    quant_ifm = create_const_tensor(
        "const_quant_ifm", values=np.array([127, 127]), value_dtype=np.int8, shape=[], dtype=DataType.int8
    )
    quant_ifm.quantization = testutil.default_quant_params()
    quant_ifm.quantization.scale_f32 = 0.15748031
    quant_ifm.quantization.quant_min = -128
    quant_ifm.quantization.quant_max = 127

    quant_ofm = create_const_tensor("const_quant_ofm", values=np.array([]), shape=[], dtype=DataType.int8)
    quant_ofm.quantization = testutil.default_quant_params()
    quant_ofm.quantization.scale_f32 = 0.036092404
    quant_ofm.quantization.zero_point = -128
    quant_ofm.quantization.quant_min = -128
    quant_ofm.quantization.quant_max = 127

    # Create quant op
    quant_op = testutil.create_op(Op.Quantize, [quant_ifm], quant_ofm)
    quant_op.run_on_npu = True

    op: Operation = optimise_quantize(quant_op, None, None)

    assert (op.ofm.values == np.array([127, 127])).all()