# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for tflite_model_semantic
import numpy as np

from ethosu.vela.data_type import DataType
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
from ethosu.vela.tflite_model_semantic import TFLiteSemantic

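# A single TFLiteSemantic instance is shared by all tests in this module.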
semantic_checker = TFLiteSemantic()


def test_constraint_tens_no_dynamic():
    # Tensors cannot be dynamic (no shape, not a scalar)
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_tens_defined_shape():
    # Tensors cannot have None in them
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_tens_output_scalar():
    # Scalar output is not allowed at all:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
    op.ofm.values = 0.5
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_tens_input_scalar():
    # Shapeless input is allowed if it is of a certain type:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
    assert semantic_checker.is_operator_semantic_valid(op)
    # Invalid shapeless input due to op type:
    op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
    op.ifm.values = 0.5
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_tens_shape_size():
    # Tensors cannot be > 4D
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_tens_quant_none_check():
    # Tensors must have quantization parameters
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_tens_quant_scale():
    # Quantization scale cannot be infinite
    qp = QuantizationParameters()
    qp.zero_point = 0
    qp.scale_f32 = np.inf
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_fc_output_2d_not_supp():
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 2, 1], weights_shape=[12, 1, 1, 1])
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 3, 4], weights_shape=[12, 1, 1, 1])
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1], weights_shape=[1, 1, 1, 1])
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_fc_output_2d_is_supp():
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
    assert semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
    assert semantic_checker.is_operator_semantic_valid(op)


def test_constraint_conv_pass():
    # First test a simple conv passes
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
    op.attrs = {"stride_w": 1, "stride_h": 1}
    assert semantic_checker.is_operator_semantic_valid(op)


def test_constraint_stride_type():
    # Stride width and height must be integer types
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1.5, "stride_h": "1"}
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_dilation_type():
    # Dilation width and height must be integer types
    op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_quant_scale_inf():
    # The ratio of IFM scale to OFM scale must not overflow to infinity
    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm.quantization.scale_f32 = np.float32(1e9)
    op.ofm.quantization.scale_f32 = np.float32(1e-35)
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_ofm_scale_too_small():
    # Tests handling of OFM scale < 1e-38
    shp = [1, 10, 20, 16]
    op = testutil.create_elemwise_op(Op.Mul, "mul", shp, shp, shp, ofm_quant=testutil.default_quant_params(),)
    assert semantic_checker.is_operator_semantic_valid(op)
    op.ofm.quantization.scale_f32 = 1e-43
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_matching_in_out_types():
    # Valid
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
    assert semantic_checker.is_operator_semantic_valid(op)
    # Invalid: the ifm and ofm data types must match (the default is uint8)
    op.ifm.dtype = DataType.int8
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_filter_type():
    # Filter width/height must be integers
    op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_matching_shapes():
    # Softmax requires the ifm and ofm shapes to match
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    assert semantic_checker.is_operator_semantic_valid(op)


def test_constraint_beta_value_range():
    # beta must not be negative (zero is allowed)
    op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    op.attrs["beta"] = -1.0
    assert not semantic_checker.is_operator_semantic_valid(op)
    op.attrs["beta"] = 0.0
    assert semantic_checker.is_operator_semantic_valid(op)


def test_constraint_splitv_inferred():
    # SplitV requires a maximum of one inferred shape (-1)
    qp = testutil.default_quant_params()
    op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
    op.add_input_tensor(sizes)
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
    op.add_input_tensor(sizes)
    assert semantic_checker.is_operator_semantic_valid(op)


def test_constraint_concat_pass():
    # A working concat
    op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert semantic_checker.is_operator_semantic_valid(op)


def test_constraint_axis_exists():
    # Missing axis attribute
    op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_axis_valid():
    # Invalid axis attribute
    op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 7
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_matching_dimensionality():
    # Mismatching dimensionality: 4D+2D=4D
    op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_valid_dimensions():
    # Mismatching dimension value:
    # ifm2 has w and h as 2, which is not the axis to concat and doesn't match ifm1 or ofm
    op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
    ifm2.quantization = testutil.default_quant_params()
    op.add_input_tensor(ifm2)
    op.attrs["axis"] = 3
    assert not semantic_checker.is_operator_semantic_valid(op)


def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
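    # Helper: builds a StridedSlice op with constant begin/end/strides inputs
    # (stride 1 on every axis) and all masks cleared.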
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, DataType.uint8, "in")
    in0.quantization = qp
    in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
    in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
    in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
    out = Tensor(out_shape, DataType.uint8, "out")
    out.quantization = qp
    attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
    return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)


def create_pad_op(
    in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
):
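    # Helper: builds a Pad op with a constant padding tensor; the output reuses a clone
    # of the default quantization parameters.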
    qp = testutil.default_quant_params()
    in0 = Tensor(in_shape, in_dtype, "in")
    in0.quantization = qp
    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    out = Tensor(out_shape, out_dtype, "out")
    out.quantization = qp.clone()
    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
    return op


def test_constraint_pad_input_count():
    # Pad must have exactly two input tensors; adding a third makes it invalid
    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
    assert semantic_checker.is_operator_semantic_valid(op)
    op.add_input_tensor(op.inputs[0].clone())
    assert not semantic_checker.is_operator_semantic_valid(op)


def create_strided_slice():
    # Creates a valid strided slice operator with some valid inputs/outputs
    op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
    op.attrs["begin_mask"] = 1
    op.attrs["end_mask"] = 9
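    # Per TFLite StridedSlice semantics, each set mask bit makes the corresponding
    # begin/end value be ignored: begin_mask=1 masks begin on axis 0 and
    # end_mask=9 (0b1001) masks end on axes 0 and 3, so the slice above is valid.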
    assert semantic_checker.is_operator_semantic_valid(op)
    return op


def test_constraint_stridedslice_input_count():
    # Wrong number of input tensors
    op = create_strided_slice()
    op.add_input_tensor(op.inputs[0].clone())
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_stridedslice_inputs_const():
    # begin, end, stride values must not be None
    op = create_strided_slice()
    op.inputs[1].values = None
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_strided_slice()
    op.inputs[2].values = None
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_strided_slice()
    op.inputs[3].values = None
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_ellipsis_mask():
    # A non-zero ellipsis mask is not allowed
    op = create_strided_slice()
    op.attrs["ellipsis_mask"] = 1
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_axis_masks():
    op = create_strided_slice()
    # Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
    op.attrs["new_axis_mask"] = 2
    assert semantic_checker.is_operator_semantic_valid(op)
    op = create_strided_slice()
    op.attrs["shrink_axis_mask"] = 3
    assert semantic_checker.is_operator_semantic_valid(op)
    # But setting both to non-zero is not allowed
    op.attrs["new_axis_mask"] = 2
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_slice_ranges():
    # Examples where end offset <= begin offset
    op = create_strided_slice()
    op.inputs[1].values = [0, 7, 2, 0]
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_strided_slice()
    op.inputs[2].values = [0, 7, 2, 0]
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_strided_slice()
    op.attrs["begin_mask"] = 0
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_strided_slice()
    op.attrs["end_mask"] = 0
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_matching_inputs_types():
    # input data types must match (default is uint8)
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    op.ifm2.dtype = DataType.int8
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_matching_signed():
    # signed inputs require output to also be signed
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
    op.ofm.dtype = DataType.uint8
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_unsigned_valid():
    # unsigned inputs require output to be either:
    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    # the same (default uint8)
    assert semantic_checker.is_operator_semantic_valid(op)
    op.ofm.dtype = DataType.int8
    assert not semantic_checker.is_operator_semantic_valid(op)
    op.ofm.dtype = DataType.int16
    assert not semantic_checker.is_operator_semantic_valid(op)
    # or int32
    op.ofm.dtype = DataType.int32
    assert semantic_checker.is_operator_semantic_valid(op)


def test_constraint_matching_either_shapes():
    # BINARY CASE
    # At least one ifm shape must match ofm's shape
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
    assert semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
    assert semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
    assert not semantic_checker.is_operator_semantic_valid(op)

    # UNARY CASE
    # No second input so this is treated the same as requiring ifm shape to match ofm shape
    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
    assert semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_alpha_valid():
    # Alpha cannot be negative
    op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
    op.attrs["alpha"] = 0
    assert semantic_checker.is_operator_semantic_valid(op)
    op.attrs["alpha"] = -1
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_hardswish_dtype():
    # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
    # UINT8
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
    assert semantic_checker.is_operator_semantic_valid(op)
    # INT8
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
    assert semantic_checker.is_operator_semantic_valid(op)

    # Invalid
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
    assert not semantic_checker.is_operator_semantic_valid(op)

    in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
    out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
    op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_constraint_keep_dims_ifm_ofm():
    op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
    op.attrs["keep_num_dims"] = True
    assert not semantic_checker.is_operator_semantic_valid(op)
    op.attrs["keep_num_dims"] = False
    assert semantic_checker.is_operator_semantic_valid(op)


def create_mean(input_shape, output_shape, axis, datatype, attrs):
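    # Helper: builds a Mean op whose reduction axis is passed as a constant
    # indices tensor (either a list of axes or a single scalar axis).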
    ifm = Tensor(input_shape, datatype, "in")
    ifm.quantization = testutil.default_quant_params()
    ofm = Tensor(output_shape, datatype, "out")
    ofm.quantization = testutil.default_quant_params()
    if type(axis) is list:
        indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
    elif type(axis) is int:
        indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
    op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
    return op


def test_mean_dtype():
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert semantic_checker.is_operator_semantic_valid(op)
    op.ifm.dtype = DataType.int16
    op.ofm.dtype = DataType.int16
    assert not semantic_checker.is_operator_semantic_valid(op)


def test_mean_axis():
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
    assert not semantic_checker.is_operator_semantic_valid(op)
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert semantic_checker.is_operator_semantic_valid(op)
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
    assert semantic_checker.is_operator_semantic_valid(op)
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
    assert semantic_checker.is_operator_semantic_valid(op)
    op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
    assert semantic_checker.is_operator_semantic_valid(op)


def test_matching_in_out_quant():
    # quantisation parameters of ifm and ofm should match.
    quant = testutil.default_quant_params()
    # create reshape op
    ifm_shape = [64, 16]
    ofm_shape = [1, 4, 16, 16]
    ifm = create_const_tensor("reshape_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("reshape_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    shape_tens = create_const_tensor("shape", [1], DataType.int32, ofm_shape)
    op = testutil.create_op(Op.Reshape, [ifm, shape_tens], ofm, set_ifm_ofm_shapes=False)
    op.attrs["new_shape"] = ofm_shape

    # Matching quantisation parameters
    assert semantic_checker.is_operator_semantic_valid(op)

    # Different zp
    ofm.quantization.zero_point = 32
    assert not semantic_checker.is_operator_semantic_valid(op)

    # Different scale
    ofm.quantization.zero_point = 0
    ofm.quantization.scale_f32 = 0.9
    assert not semantic_checker.is_operator_semantic_valid(op)

    # Squeeze op diff quant
    # create squeeze op
    ifm_shape = [1, 1, 1, 1001]
    ofm_shape = [1, 1001]
    ifm = create_const_tensor("squeeze_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    ofm.quantization.zero_point = 32
    op = testutil.create_op(Op.Squeeze, [ifm], ofm, set_ifm_ofm_shapes=False)
    op.attrs["squeeze_dims"] = [1, 2]
    assert not semantic_checker.is_operator_semantic_valid(op)

    # ExpandDims diff quant
    quant = testutil.default_quant_params()
    # create expand_dims op
    ifm_shape = [4, 16, 16]
    ofm_shape = [1, 4, 16, 16]
    ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    ofm.quantization.zero_point = 32
    dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
    op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
    assert not semantic_checker.is_operator_semantic_valid(op)