blob: e290dd2c4438399298b186e0848e8a0c0ccadff7 [file] [log] [blame]
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Unit tests for tflite_model_semantic
19import numpy as np
20
21from ethosu.vela.data_type import DataType
22from ethosu.vela.operation import Op
23from ethosu.vela.operation import Padding
24from ethosu.vela.tensor import create_const_tensor
25from ethosu.vela.tensor import QuantizationParameters
26from ethosu.vela.tensor import Tensor
27from ethosu.vela.test import testutil
28from ethosu.vela.tflite_model_semantic import TFLiteSemantic
29
# A single TFLiteSemantic checker instance, shared by every test in this module.
semantic_checker: TFLiteSemantic = TFLiteSemantic()
31
32
def test_constraint_tens_no_dynamic():
    """Dynamic tensors (no shape, but not a scalar either) are rejected."""
    relu = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
    is_valid = semantic_checker.is_operator_semantic_valid(relu)
    assert not is_valid
37
38
def test_constraint_tens_defined_shape():
    """A tensor shape containing None (an undefined dimension) is rejected."""
    undefined_ifm_shape = [1, 8, None, 8]
    relu = testutil.create_op_with_quant_tensors(Op.Relu, undefined_ifm_shape, [1, 8, 8, 8])
    is_valid = semantic_checker.is_operator_semantic_valid(relu)
    assert not is_valid
43
44
def test_constraint_tens_output_scalar():
    """A scalar output tensor is never allowed, regardless of op type."""
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
    # Give the shapeless OFM a value so it is a scalar rather than a dynamic tensor.
    mul.ofm.values = 0.5
    is_valid = semantic_checker.is_operator_semantic_valid(mul)
    assert not is_valid
50
51
def test_constraint_tens_input_scalar():
    """Shapeless (scalar) inputs are accepted only for certain op types."""
    # Mul accepts a shapeless second input.
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
    assert semantic_checker.is_operator_semantic_valid(mul)

    # Relu does not: a scalar IFM is invalid because of the op type.
    relu = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
    relu.ifm.values = 0.5
    assert not semantic_checker.is_operator_semantic_valid(relu)
60
61
def test_constraint_tens_shape_size():
    """Tensors with more than four dimensions are rejected."""
    relu = testutil.create_op_with_quant_tensors(
        Op.Relu,
        [1, 1, 8, 8, 8],
        [1, 1, 8, 8, 8],
        set_ifm_ofm_shapes=False,
    )
    is_valid = semantic_checker.is_operator_semantic_valid(relu)
    assert not is_valid
66
67
def test_constraint_tens_quant_none_check():
    """Every tensor must carry quantization parameters."""
    # The scalar IFM2 is created with no quantization at all.
    mul = testutil.create_elemwise_op(
        Op.Mul,
        "op",
        [1, 8, 8, 8],
        [],
        [1, 8, 8, 8],
        ifm2_quant=None,
    )
    is_valid = semantic_checker.is_operator_semantic_valid(mul)
    assert not is_valid
72
73
def test_constraint_tens_quant_scale():
    """An infinite quantization scale is rejected."""
    infinite_quant = QuantizationParameters()
    infinite_quant.zero_point = 0
    infinite_quant.scale_f32 = np.inf
    mul = testutil.create_elemwise_op(
        Op.Mul,
        "op",
        [1, 8, 8, 8],
        [],
        [1, 8, 8, 8],
        ifm_quant=infinite_quant,
    )
    is_valid = semantic_checker.is_operator_semantic_valid(mul)
    assert not is_valid
81
82
def test_constraint_fc_output_2d_not_supp():
    """Invalid FullyConnected IFM/OFM/weight shape combinations are rejected."""
    # (ifm_shape, ofm_shape, weights_shape)
    rejected_cases = (
        ([7, 4, 6], [3, 2, 2, 8], [1, 9, 1]),
        ([12, 1, 6, 1], [3, 7, 4], [1, 1, 7, 1]),
        ([4, 1, 4, 7], [1, 9], [12, 3]),
        ([4], [9], [3, 2]),
    )
    for ifm_shape, ofm_shape, weights_shape in rejected_cases:
        fc = testutil.create_op_with_quant_tensors(
            Op.FullyConnected, ifm_shape, ofm_shape, weights_shape=weights_shape
        )
        assert not semantic_checker.is_operator_semantic_valid(fc)
92
93
def test_constraint_fc_output_2d_is_supp():
    """Valid FullyConnected IFM/OFM/weight shape combinations are accepted."""
    # (ifm_shape, ofm_shape, weights_shape)
    accepted_cases = (
        ([4, 8, 8, 4], [32, 32], [4, 8, 8, 4]),
        ([1, 1024], [16, 64], [1, 1024]),
        ([12, 1], [3, 2, 1, 1], [12, 1, 1, 1]),
        ([12, 1], [3, 2, 1], [12, 1, 1, 1]),
        ([12, 1], [1, 1, 3, 2], [12, 1, 1, 1]),
        ([12, 1, 1, 1], [1, 1, 1], [12, 1, 1, 1]),
        ([12, 1, 1, 1], [1, 1, 24], [12, 1, 1, 1]),
        ([1, 1, 1, 1], [1, 3, 4], [1, 1, 1, 1]),
    )
    for ifm_shape, ofm_shape, weights_shape in accepted_cases:
        fc = testutil.create_op_with_quant_tensors(
            Op.FullyConnected, ifm_shape, ofm_shape, weights_shape=weights_shape
        )
        assert semantic_checker.is_operator_semantic_valid(fc)
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200113
114
def test_constraint_conv_pass():
    """Baseline: a minimal 1x1 Conv2DBias with unit strides is valid."""
    conv = testutil.create_op_with_quant_tensors(
        Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1]
    )
    conv.attrs = dict(stride_w=1, stride_h=1)
    assert semantic_checker.is_operator_semantic_valid(conv)
120
121
def test_constraint_stride_type():
    """Stride width/height must be integers; a float and a string are rejected."""
    conv = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    conv.attrs = dict(stride_w=1.5, stride_h="1")
    is_valid = semantic_checker.is_operator_semantic_valid(conv)
    assert not is_valid
127
128
def test_constraint_dilation_type():
    """Dilation factors must be integers; a float and a string are rejected."""
    conv = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    conv.attrs = dict(stride_w=1, stride_h=1, dilation_w_factor=1.5, dilation_h_factor="1")
    is_valid = semantic_checker.is_operator_semantic_valid(conv)
    assert not is_valid
134
135
def test_constraint_quant_scale_inf():
    """An IFM-scale / OFM-scale ratio that is infinite in float32 is rejected."""
    relu = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    # 1e9 / 1e-35 = 1e44, which exceeds the float32 range.
    huge_ifm_scale = np.float32(1e9)
    tiny_ofm_scale = np.float32(1e-35)
    relu.ifm.quantization.scale_f32 = huge_ifm_scale
    relu.ofm.quantization.scale_f32 = tiny_ofm_scale
    assert not semantic_checker.is_operator_semantic_valid(relu)
142
143
def test_constraint_ofm_scale_too_small():
    """An OFM scale below roughly 1e-38 is rejected."""
    shape = [1, 10, 20, 16]
    mul = testutil.create_elemwise_op(
        Op.Mul, "mul", shape, shape, shape, ofm_quant=testutil.default_quant_params()
    )
    # Valid with the default OFM quantization...
    assert semantic_checker.is_operator_semantic_valid(mul)
    # ...but not once the OFM scale is made tiny.
    mul.ofm.quantization.scale_f32 = 1e-43
    assert not semantic_checker.is_operator_semantic_valid(mul)
158
159
def test_constraint_matching_in_out_types():
    """AvgPool requires the IFM and OFM data types to match."""
    pool = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool.attrs = dict(stride_w=2, stride_h=2, filter_width=2, filter_height=2, padding=Padding.SAME)
    # Both tensors use the default datatype (uint8), so this is valid.
    assert semantic_checker.is_operator_semantic_valid(pool)
    # Changing only the IFM datatype introduces a mismatch.
    pool.ifm.dtype = DataType.int8
    assert not semantic_checker.is_operator_semantic_valid(pool)
168
169
def test_constraint_filter_type():
    """Filter width/height must be integers; a float and a string are rejected."""
    pool = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool.attrs = dict(stride_w=2, stride_h=2, filter_width=2.5, filter_height="2", padding=Padding.SAME)
    is_valid = semantic_checker.is_operator_semantic_valid(pool)
    assert not is_valid
175
176
def test_constraint_matching_shapes():
    """Softmax requires the IFM and OFM shapes to be identical."""
    # (ifm_shape, ofm_shape, expected_validity)
    cases = (
        ([1, 1, 1, 8], [1, 2, 2, 4], False),
        ([1, 1, 1, 8], [1, 1, 1, 8], True),
    )
    for ifm_shape, ofm_shape, expected in cases:
        softmax = testutil.create_op_with_quant_tensors(Op.Softmax, ifm_shape, ofm_shape)
        assert bool(semantic_checker.is_operator_semantic_valid(softmax)) is expected
183
184
def test_constraint_beta_value_range():
    """Softmax beta must be non-negative: -1.0 is rejected, 0.0 is accepted."""
    softmax = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    # The same op is reused; each iteration overwrites the previous beta.
    for beta, expected in ((-1.0, False), (0.0, True)):
        softmax.attrs["beta"] = beta
        assert bool(semantic_checker.is_operator_semantic_valid(softmax)) is expected
192
193
def test_constraint_splitv_inferred():
    """SplitV allows at most one inferred (-1) entry in its sizes tensor."""
    quant = testutil.default_quant_params()
    cases = (
        ([[[[0, -1, 2, -1]]]], False),  # two inferred sizes
        ([[[[0, 1, 2, -1]]]], True),  # exactly one inferred size
    )
    for size_values, expected in cases:
        split = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
        sizes = create_const_tensor(
            "sizes", [1, 1, 1, 4], DataType.int16, size_values, np.int16, quantization=quant
        )
        split.add_input_tensor(sizes)
        assert bool(semantic_checker.is_operator_semantic_valid(split)) is expected
205
206
def test_constraint_concat_pass():
    """Baseline: concatenating two [1, 1, 1, 4] tensors along axis 3 is valid."""
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    second_ifm = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    second_ifm.quantization = testutil.default_quant_params()
    concat.add_input_tensor(second_ifm)
    concat.attrs["axis"] = 3
    is_valid = semantic_checker.is_operator_semantic_valid(concat)
    assert is_valid
215
216
def test_constraint_axis_exists():
    """A concat without an "axis" attribute is rejected."""
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    second_ifm = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    second_ifm.quantization = testutil.default_quant_params()
    concat.add_input_tensor(second_ifm)
    # Deliberately no concat.attrs["axis"].
    is_valid = semantic_checker.is_operator_semantic_valid(concat)
    assert not is_valid
224
225
def test_constraint_axis_valid():
    """A concat axis outside the tensor rank (7 for 4D inputs) is rejected."""
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    second_ifm = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    second_ifm.quantization = testutil.default_quant_params()
    concat.add_input_tensor(second_ifm)
    concat.attrs["axis"] = 7
    is_valid = semantic_checker.is_operator_semantic_valid(concat)
    assert not is_valid
234
235
def test_constraint_matching_dimensionality():
    """Concat inputs must have the same rank; 4D + 2D -> 4D is rejected."""
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    two_d_ifm = Tensor([1, 4], DataType.uint8, "in2")
    two_d_ifm.quantization = testutil.default_quant_params()
    concat.add_input_tensor(two_d_ifm)
    concat.attrs["axis"] = 3
    is_valid = semantic_checker.is_operator_semantic_valid(concat)
    assert not is_valid
244
245
def test_constraint_valid_dimensions():
    """Non-concat dimensions must agree across inputs and output.

    The second input has H and W equal to 2. Those are not the concat axis and
    match neither the first input nor the output.
    """
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    mismatched_ifm = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
    mismatched_ifm.quantization = testutil.default_quant_params()
    concat.add_input_tensor(mismatched_ifm)
    concat.attrs["axis"] = 3
    is_valid = semantic_checker.is_operator_semantic_valid(concat)
    assert not is_valid
255
256
def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
    """Build a uint8 StridedSlice op with constant begin/end/stride inputs.

    The stride tensor is all ones, with one entry per element of ``end_offsets``.
    Every tensor gets the same default quantization parameters object, and all
    five mask attributes start at 0.
    """
    quant = testutil.default_quant_params()
    end_len = len(end_offsets)

    ifm = Tensor(in_shape, DataType.uint8, "in")
    ifm.quantization = quant
    begin = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=quant)
    end = create_const_tensor("end", [end_len], DataType.uint8, end_offsets, quantization=quant)
    strides = create_const_tensor("strides", [end_len], DataType.uint8, [1] * end_len, quantization=quant)
    ofm = Tensor(out_shape, DataType.uint8, "out")
    ofm.quantization = quant

    mask_names = ("ellipsis_mask", "new_axis_mask", "shrink_axis_mask", "begin_mask", "end_mask")
    mask_attrs = dict.fromkeys(mask_names, 0)
    return testutil.create_op(Op.StridedSlice, [ifm, begin, end, strides], ofm, attrs=mask_attrs)
268
269
def create_pad_op(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
):
    """Build a Pad op that pads an ``in_shape`` IFM into an ``out_shape`` OFM.

    ``padding`` becomes a constant tensor whose shape is ``np.shape(padding)``;
    no quantization argument is passed for it. The OFM gets a clone of the
    IFM's default quantization parameters.
    """
    ifm_quant = testutil.default_quant_params()
    ifm = Tensor(in_shape, in_dtype, "in")
    ifm.quantization = ifm_quant
    paddings = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    ofm = Tensor(out_shape, out_dtype, "out")
    ofm.quantization = ifm_quant.clone()
    return testutil.create_op(Op.Pad, [ifm, paddings], ofm)
286
287
def test_constraint_pad_input_count():
    """Pad takes exactly two inputs; adding a third makes it invalid."""
    pad = create_pad_op(
        in_shape=[1, 1, 1, 1],
        out_shape=[1, 3, 3, 1],
        padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
    )
    assert semantic_checker.is_operator_semantic_valid(pad)

    surplus_input = pad.inputs[0].clone()
    pad.add_input_tensor(surplus_input)
    assert not semantic_checker.is_operator_semantic_valid(pad)
298
299
def create_strided_slice():
    """Return a StridedSlice op that passes the semantic checks.

    Validity is asserted before returning, so that a failure in a test that
    modifies the op is attributable to the modification, not the baseline.
    """
    slice_op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
    # begin_mask=0b0001 and end_mask=0b1001. Presumably (per TFLite StridedSlice
    # semantics) this makes the begin of axis 0 and the ends of axes 0 and 3
    # ignored; those offsets would otherwise be invalid.
    slice_op.attrs["begin_mask"] = 1
    slice_op.attrs["end_mask"] = 9
    assert semantic_checker.is_operator_semantic_valid(slice_op)
    return slice_op
307
308
def test_constraint_stridedslice_input_count():
    """StridedSlice with a surplus (fifth) input tensor is rejected."""
    slice_op = create_strided_slice()
    surplus_input = slice_op.inputs[0].clone()
    slice_op.add_input_tensor(surplus_input)
    assert not semantic_checker.is_operator_semantic_valid(slice_op)
314
315
def test_constraint_stridedslice_inputs_const():
    """The begin, end and stride inputs (indices 1-3) must all have values."""
    for input_index in (1, 2, 3):
        # Start from a fresh valid op each time so only one input is broken.
        slice_op = create_strided_slice()
        slice_op.inputs[input_index].values = None
        assert not semantic_checker.is_operator_semantic_valid(slice_op)
327
328
def test_constraint_ellipsis_mask():
    """A non-zero ellipsis_mask is unsupported."""
    slice_op = create_strided_slice()
    slice_op.attrs["ellipsis_mask"] = 1
    is_valid = semantic_checker.is_operator_semantic_valid(slice_op)
    assert not is_valid
334
335
def test_constraint_axis_masks():
    """new_axis_mask and shrink_axis_mask may each be set, but not both at once."""
    # Setting either one of them to non-zero on its own is fine.
    for mask_name, mask_value in (("new_axis_mask", 2), ("shrink_axis_mask", 3)):
        slice_op = create_strided_slice()
        slice_op.attrs[mask_name] = mask_value
        assert semantic_checker.is_operator_semantic_valid(slice_op)

    # slice_op is the last one built above and still has shrink_axis_mask set.
    # Also setting new_axis_mask is not supported.
    slice_op.attrs["new_axis_mask"] = 2
    assert not semantic_checker.is_operator_semantic_valid(slice_op)
347
348
def test_constraint_slice_ranges():
    """Cases where some end offset is <= its begin offset are rejected."""
    # Overwrite either the begin offsets (input 1) or the end offsets (input 2).
    for input_index in (1, 2):
        slice_op = create_strided_slice()
        slice_op.inputs[input_index].values = [0, 7, 2, 0]
        assert not semantic_checker.is_operator_semantic_valid(slice_op)

    # Clear one of the masks so that the masked-out offsets apply again.
    for mask_name in ("begin_mask", "end_mask"):
        slice_op = create_strided_slice()
        slice_op.attrs[mask_name] = 0
        assert not semantic_checker.is_operator_semantic_valid(slice_op)
363
364
def test_constraint_matching_inputs_types():
    """Binary elementwise inputs must share a datatype (the default is uint8)."""
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    # IFM stays uint8 while IFM2 becomes int8.
    mul.ifm2.dtype = DataType.int8
    is_valid = semantic_checker.is_operator_semantic_valid(mul)
    assert not is_valid
370
371
def test_constraint_matching_signed():
    """Signed inputs require a signed output; int8 inputs with a uint8 OFM are rejected."""
    mul = testutil.create_elemwise_op(
        Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8
    )
    mul.ofm.dtype = DataType.uint8
    is_valid = semantic_checker.is_operator_semantic_valid(mul)
    assert not is_valid
377
378
def test_constraint_unsigned_valid():
    """With unsigned (uint8) inputs, the output must be uint8 or int32."""
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    # Output has the same datatype as the inputs (default uint8): valid.
    assert semantic_checker.is_operator_semantic_valid(mul)
    for ofm_dtype, expected in (
        (DataType.int8, False),
        (DataType.int16, False),
        (DataType.int32, True),
    ):
        mul.ofm.dtype = ofm_dtype
        assert bool(semantic_checker.is_operator_semantic_valid(mul)) is expected
391
392
def test_constraint_matching_either_shapes():
    """At least one IFM shape must equal the OFM shape."""
    # Binary case: (ifm_shape, ifm2_shape, ofm_shape, expected_validity)
    binary_cases = (
        ([1, 4], [4, 4], [4, 4], True),
        ([4, 4], [1, 4], [4, 4], True),
        ([4, 4], [4, 4], [2, 2], False),
        ([1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16], False),
        ([1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16], False),
    )
    for ifm_shape, ifm2_shape, ofm_shape, expected in binary_cases:
        add = testutil.create_elemwise_op(Op.Add, "op", ifm_shape, ifm2_shape, ofm_shape)
        assert bool(semantic_checker.is_operator_semantic_valid(add)) is expected

    # Unary case: with no second input, the single IFM shape must equal the OFM shape.
    unary_cases = (
        ([2, 2], [2, 2], True),
        ([4, 4], [2, 2], False),
    )
    for ifm_shape, ofm_shape, expected in unary_cases:
        clz = testutil.create_elemwise_op(Op.CLZ, "op", ifm_shape, None, ofm_shape, datatype=DataType.int32)
        assert bool(semantic_checker.is_operator_semantic_valid(clz)) is expected
413
414
def test_constraint_alpha_valid():
    """LeakyRelu alpha may be negative only for 8-bit data.

    Here a negative alpha is rejected for int16 and accepted for int8. Zero is
    accepted for both.
    """
    for dtype, negative_alpha_allowed in ((DataType.int16, False), (DataType.int8, True)):
        leaky_relu = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2], dtype)
        leaky_relu.attrs["alpha"] = 0
        assert semantic_checker.is_operator_semantic_valid(leaky_relu)
        leaky_relu.attrs["alpha"] = -1
        assert bool(semantic_checker.is_operator_semantic_valid(leaky_relu)) is negative_alpha_allowed
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200427
428
def test_constraint_hardswish_dtype():
    """HardSwish supports only int8/uint8 data, and the input dtype must match the output dtype."""
    # Supported: uint8 (the default datatype)...
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
    assert semantic_checker.is_operator_semantic_valid(op)
    # ...and int8.
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
    assert semantic_checker.is_operator_semantic_valid(op)

    # Unsupported datatypes.
    for dtype in (DataType.int16, DataType.uint16, DataType.int32):
        op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=dtype)
        assert not semantic_checker.is_operator_semantic_valid(op)

    # Mismatching input/output dtypes. Both tensors get quantization parameters, so
    # the op is rejected because of the dtype mismatch and not because quantization
    # is missing (a separate constraint, see test_constraint_tens_quant_none_check).
    in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
    in_tens.quantization = testutil.default_quant_params()
    out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
    out_tens.quantization = testutil.default_quant_params()
    op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
    assert not semantic_checker.is_operator_semantic_valid(op)
    # Control: the same op becomes valid once the dtypes match, which proves the
    # rejection above came from the dtype mismatch.
    out_tens.dtype = DataType.int8
    assert semantic_checker.is_operator_semantic_valid(op)
450
451
def test_constraint_keep_dims_ifm_ofm():
    """keep_num_dims=True is invalid here because the 4D IFM and 2D OFM have different ranks."""
    fc = testutil.create_op_with_quant_tensors(
        Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4]
    )
    # The same op is reused; each iteration overwrites keep_num_dims.
    for keep_num_dims, expected in ((True, False), (False, True)):
        fc.attrs["keep_num_dims"] = keep_num_dims
        assert bool(semantic_checker.is_operator_semantic_valid(fc)) is expected
458
459
def create_mean(input_shape, output_shape, axis, datatype, attrs):
    """Build a Mean op that reduces an ``input_shape`` IFM to an ``output_shape`` OFM.

    Args:
        input_shape: shape of the IFM.
        output_shape: shape of the OFM.
        axis: a single axis (int, stored as a 0-D constant tensor) or a sequence
            of axes (list/tuple, stored as a 1-D constant tensor).
        datatype: DataType used for both the IFM and the OFM.
        attrs: operator attributes, e.g. {"keep_dims": True}.

    Returns:
        The Mean operation, with IFM and OFM carrying default quantization.

    Raises:
        TypeError: if ``axis`` is neither an int nor a list/tuple.
    """
    ifm = Tensor(input_shape, datatype, "in")
    ifm.quantization = testutil.default_quant_params()
    ofm = Tensor(output_shape, datatype, "out")
    ofm.quantization = testutil.default_quant_params()
    # Store the axis values as int32 to match the declared DataType.int32. The
    # previous np.uint8 storage could not represent negative axes.
    if isinstance(axis, (list, tuple)):
        indices = create_const_tensor("indices", [len(axis)], DataType.int32, list(axis), np.int32)
    elif isinstance(axis, (int, np.integer)) and not isinstance(axis, bool):
        indices = create_const_tensor("indices", [], DataType.int32, axis, np.int32)
    else:
        # Previously this fell through and raised UnboundLocalError on `indices`.
        raise TypeError(f"axis must be an int or a list/tuple of ints, got {type(axis).__name__}")
    return testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
471
472
def test_mean_dtype():
    """Mean over axes [1, 2] is valid for int8 but rejected for int16."""
    mean = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert semantic_checker.is_operator_semantic_valid(mean)

    # Switch both IFM and OFM to int16.
    mean.ifm.dtype = DataType.int16
    mean.ofm.dtype = DataType.int16
    assert not semantic_checker.is_operator_semantic_valid(mean)
479
480
def test_mean_axis():
    """For a 4D int8 Mean, only reductions over axes 1 and/or 2 are accepted."""
    # (axis, expected_validity); an int axis is a scalar, a list axis is 1-D.
    cases = (
        (0, False),
        ([3], False),
        ([1, 3], False),
        ([0, 1], False),
        ([1, 2], True),
        ([1], True),
        (2, True),
        ([2, 1], True),
    )
    for axis, expected in cases:
        mean = create_mean([1, 6, 6, 16], [1, 1, 1, 16], axis, DataType.int8, {"keep_dims": True})
        assert bool(semantic_checker.is_operator_semantic_valid(mean)) is expected
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +0200498
499
def test_matching_in_out_quant():
    """Reshape, Squeeze and ExpandDims require identical IFM and OFM quantization."""
    quant = testutil.default_quant_params()

    # Reshape: [64, 16] -> [1, 4, 16, 16]
    ifm_shape = [64, 16]
    ofm_shape = [1, 4, 16, 16]
    ifm = create_const_tensor("reshape_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("reshape_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    # The new-shape tensor is 1-D with one entry per output dimension. It was
    # previously declared with shape [1] while holding len(ofm_shape) values.
    shape_tens = create_const_tensor("shape", [len(ofm_shape)], DataType.int32, ofm_shape)
    op = testutil.create_op(Op.Reshape, [ifm, shape_tens], ofm, set_ifm_ofm_shapes=False)
    op.attrs["new_shape"] = ofm_shape

    # Matching quantisation parameters
    assert semantic_checker.is_operator_semantic_valid(op)

    # Different zero point
    ofm.quantization.zero_point = 32
    assert not semantic_checker.is_operator_semantic_valid(op)

    # Different scale (zero point reset to 0)
    ofm.quantization.zero_point = 0
    ofm.quantization.scale_f32 = 0.9
    assert not semantic_checker.is_operator_semantic_valid(op)

    # Squeeze with a different OFM zero point: [1, 1, 1, 1001] -> [1, 1001]
    ifm_shape = [1, 1, 1, 1001]
    ofm_shape = [1, 1001]
    ifm = create_const_tensor("squeeze_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    ofm.quantization.zero_point = 32
    op = testutil.create_op(Op.Squeeze, [ifm], ofm, set_ifm_ofm_shapes=False)
    op.attrs["squeeze_dims"] = [1, 2]
    assert not semantic_checker.is_operator_semantic_valid(op)

    # ExpandDims with a different OFM zero point: [4, 16, 16] -> [1, 4, 16, 16]
    quant = testutil.default_quant_params()
    ifm_shape = [4, 16, 16]
    ofm_shape = [1, 4, 16, 16]
    ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    ofm.quantization.zero_point = 32
    dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
    op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
    assert not semantic_checker.is_operator_semantic_valid(op)
551 assert not semantic_checker.is_operator_semantic_valid(op)