# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# Description:
18# Unit tests for tflite_model_semantic
19import numpy as np
20
21from ethosu.vela.data_type import DataType
22from ethosu.vela.operation import Op
23from ethosu.vela.operation import Padding
24from ethosu.vela.tensor import create_const_tensor
25from ethosu.vela.tensor import QuantizationParameters
26from ethosu.vela.tensor import Tensor
27from ethosu.vela.test import testutil
28from ethosu.vela.tflite_model_semantic import TFLiteSemantic
29
# Single TFLiteSemantic instance shared by every test in this module.
semantic_checker = TFLiteSemantic()
31
32
def test_constraint_tens_no_dynamic():
    """A dynamic tensor (empty shape that is not a scalar) must be rejected."""
    relu = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
    assert not semantic_checker.is_operator_semantic_valid(relu)
37
38
def test_constraint_tens_defined_shape():
    """Tensor shapes may not contain None entries."""
    relu = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
    assert not semantic_checker.is_operator_semantic_valid(relu)
43
44
def test_constraint_tens_output_scalar():
    """A scalar OFM is never allowed."""
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
    mul.ofm.values = 0.5
    assert not semantic_checker.is_operator_semantic_valid(mul)
50
51
def test_constraint_tens_input_scalar():
    """A shapeless IFM is only allowed for certain operator types."""
    # Elementwise Mul accepts a shapeless second operand:
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
    assert semantic_checker.is_operator_semantic_valid(mul)
    # Relu does not accept a shapeless input:
    relu = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
    relu.ifm.values = 0.5
    assert not semantic_checker.is_operator_semantic_valid(relu)
60
61
def test_constraint_tens_shape_size():
    """Tensors with more than four dimensions are rejected."""
    relu = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
    assert not semantic_checker.is_operator_semantic_valid(relu)
66
67
def test_constraint_tens_quant_none_check():
    """Every tensor must carry quantization parameters."""
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
    assert not semantic_checker.is_operator_semantic_valid(mul)
72
73
def test_constraint_tens_quant_scale():
    """The quantization scale must be finite."""
    inf_quant = QuantizationParameters()
    inf_quant.zero_point = 0
    inf_quant.scale_f32 = np.inf
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=inf_quant)
    assert not semantic_checker.is_operator_semantic_valid(mul)
81
82
def test_constraint_fc_output_2d_not_supp():
    """FullyConnected shape combinations that are not a supported 2D output."""
    cases = [
        # (ifm_shape, ofm_shape, weights_shape)
        ([12, 1], [3, 2, 2, 1], [12, 1, 1, 1]),
        ([12, 1, 1, 1], [1, 3, 4], [12, 1, 1, 1]),
        ([1, 1, 1, 1], [1], [1, 1, 1, 1]),
    ]
    for ifm_shape, ofm_shape, w_shape in cases:
        fc = testutil.create_op_with_quant_tensors(Op.FullyConnected, ifm_shape, ofm_shape, weights_shape=w_shape)
        assert not semantic_checker.is_operator_semantic_valid(fc)
90
91
def test_constraint_fc_output_2d_is_supp():
    """FullyConnected shape combinations with a supported 2D output."""
    cases = [
        # (ifm_shape, ofm_shape, weights_shape)
        ([4, 8, 8, 4], [32, 32], [4, 8, 8, 4]),
        ([1, 1024], [16, 64], [1, 1024]),
    ]
    for ifm_shape, ofm_shape, w_shape in cases:
        fc = testutil.create_op_with_quant_tensors(Op.FullyConnected, ifm_shape, ofm_shape, weights_shape=w_shape)
        assert semantic_checker.is_operator_semantic_valid(fc)
97
98
def test_constraint_conv_pass():
    """A minimal 1x1 convolution passes all semantic checks."""
    conv = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
    conv.attrs = {"stride_w": 1, "stride_h": 1}
    assert semantic_checker.is_operator_semantic_valid(conv)
104
105
def test_constraint_stride_type():
    """Stride width/height attributes must be integer typed."""
    conv = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    conv.attrs = {"stride_w": 1.5, "stride_h": "1"}  # float and str are both invalid
    assert not semantic_checker.is_operator_semantic_valid(conv)
111
112
def test_constraint_dilation_type():
    """Dilation width/height attributes must be integer typed."""
    conv = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
    conv.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
    assert not semantic_checker.is_operator_semantic_valid(conv)
118
119
def test_constraint_quant_scale_inf():
    """Reject an invalid IFM/OFM scale combination."""
    relu = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
    # NOTE(review): the name/original comment say "infinite"; presumably the
    # 1e9 / 1e-35 ratio overflows the derived rescale to inf — confirm against
    # the semantic check implementation.
    relu.ifm.quantization.scale_f32 = np.float32(1e9)
    relu.ofm.quantization.scale_f32 = np.float32(1e-35)
    assert not semantic_checker.is_operator_semantic_valid(relu)
126
127
def test_constraint_ofm_scale_too_small():
    """OFM quantization scale below 1e-38 is rejected."""
    shape = [1, 10, 20, 16]
    mul = testutil.create_elemwise_op(Op.Mul, "mul", shape, shape, shape, ofm_quant=testutil.default_quant_params())
    assert semantic_checker.is_operator_semantic_valid(mul)
    mul.ofm.quantization.scale_f32 = 1e-43  # below the 1e-38 threshold
    assert not semantic_checker.is_operator_semantic_valid(mul)
142
143
def test_constraint_matching_in_out_types():
    """AvgPool IFM and OFM data types must match."""
    pool = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
    assert semantic_checker.is_operator_semantic_valid(pool)
    # Changing only the IFM dtype (OFM stays at the default uint8) must fail.
    pool.ifm.dtype = DataType.int8
    assert not semantic_checker.is_operator_semantic_valid(pool)
152
153
def test_constraint_filter_type():
    """Pooling filter width/height must be integer typed."""
    pool = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
    pool.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
    assert not semantic_checker.is_operator_semantic_valid(pool)
159
160
def test_constraint_matching_shapes():
    """Softmax requires identical IFM and OFM shapes."""
    softmax = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
    assert not semantic_checker.is_operator_semantic_valid(softmax)
    softmax = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    assert semantic_checker.is_operator_semantic_valid(softmax)
167
168
def test_constraint_beta_value_range():
    """Softmax beta must not be negative (zero is accepted)."""
    softmax = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
    softmax.attrs["beta"] = -1.0
    assert not semantic_checker.is_operator_semantic_valid(softmax)
    softmax.attrs["beta"] = 0.0
    assert semantic_checker.is_operator_semantic_valid(softmax)
176
177
def test_constraint_splitv_inferred():
    """SplitV allows at most one inferred (-1) size entry."""
    qp = testutil.default_quant_params()
    # Two inferred sizes: invalid.
    splitv = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
    splitv.add_input_tensor(sizes)
    assert not semantic_checker.is_operator_semantic_valid(splitv)
    # A single inferred size: valid.
    splitv = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
    sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
    splitv.add_input_tensor(sizes)
    assert semantic_checker.is_operator_semantic_valid(splitv)
189
190
def test_constraint_concat_pass():
    """A well-formed concat along the last axis is accepted."""
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    second_input = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    second_input.quantization = testutil.default_quant_params()
    concat.add_input_tensor(second_input)
    concat.attrs["axis"] = 3
    assert semantic_checker.is_operator_semantic_valid(concat)
199
200
def test_constraint_axis_exists():
    """A concat without an axis attribute is rejected."""
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    second_input = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    second_input.quantization = testutil.default_quant_params()
    concat.add_input_tensor(second_input)
    # No "axis" attribute set on purpose.
    assert not semantic_checker.is_operator_semantic_valid(concat)
208
209
def test_constraint_axis_valid():
    """A concat axis outside the tensor rank is rejected."""
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    second_input = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
    second_input.quantization = testutil.default_quant_params()
    concat.add_input_tensor(second_input)
    concat.attrs["axis"] = 7  # tensors are only 4D
    assert not semantic_checker.is_operator_semantic_valid(concat)
218
219
def test_constraint_matching_dimensionality():
    """Concat inputs must have the same rank: 4D + 2D -> 4D is invalid."""
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    second_input = Tensor([1, 4], DataType.uint8, "in2")  # 2D, mismatching rank
    second_input.quantization = testutil.default_quant_params()
    concat.add_input_tensor(second_input)
    concat.attrs["axis"] = 3
    assert not semantic_checker.is_operator_semantic_valid(concat)
228
229
def test_constraint_valid_dimensions():
    """Concat inputs must agree on all non-concat dimensions."""
    # ifm2 has h and w set to 2, which neither matches ifm1/ofm nor is the
    # concat axis, so the op is invalid.
    concat = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
    second_input = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
    second_input.quantization = testutil.default_quant_params()
    concat.add_input_tensor(second_input)
    concat.attrs["axis"] = 3
    assert not semantic_checker.is_operator_semantic_valid(concat)
239
240
def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
    """Build a StridedSlice op with unit strides and all masks cleared."""
    qp = testutil.default_quant_params()
    ifm = Tensor(in_shape, DataType.uint8, "in")
    ifm.quantization = qp
    begin = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
    end = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
    strides = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
    ofm = Tensor(out_shape, DataType.uint8, "out")
    ofm.quantization = qp
    masks = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
    return testutil.create_op(Op.StridedSlice, [ifm, begin, end, strides], ofm, attrs=masks)
252
253
def create_pad_op(
    in_shape,
    out_shape,
    padding,
    in_dtype=DataType.int8,
    out_dtype=DataType.int8,
    pad_dtype=DataType.int32,
):
    """Build a Pad op: a data input plus a constant paddings tensor."""
    qp = testutil.default_quant_params()
    ifm = Tensor(in_shape, in_dtype, "in")
    ifm.quantization = qp
    paddings = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
    ofm = Tensor(out_shape, out_dtype, "out")
    ofm.quantization = qp.clone()
    return testutil.create_op(Op.Pad, [ifm, paddings], ofm)
270
271
def test_constraint_pad_input_count():
    """Pad must have exactly two inputs (data + paddings)."""
    pad = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
    assert semantic_checker.is_operator_semantic_valid(pad)
    # Adding a third input makes the op invalid.
    pad.add_input_tensor(pad.inputs[0].clone())
    assert not semantic_checker.is_operator_semantic_valid(pad)
282
283
def create_strided_slice():
    """Return a StridedSlice op that the semantic checker accepts."""
    ss = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
    ss.attrs["begin_mask"] = 1
    ss.attrs["end_mask"] = 9
    assert semantic_checker.is_operator_semantic_valid(ss)
    return ss
291
292
def test_constraint_stridedslice_input_count():
    """StridedSlice with an extra (fifth) input is rejected."""
    ss = create_strided_slice()
    ss.add_input_tensor(ss.inputs[0].clone())
    assert not semantic_checker.is_operator_semantic_valid(ss)
298
299
def test_constraint_stridedslice_inputs_const():
    """The begin, end and strides inputs must have constant (non-None) values."""
    for input_idx in (1, 2, 3):  # begin, end, strides respectively
        ss = create_strided_slice()
        ss.inputs[input_idx].values = None
        assert not semantic_checker.is_operator_semantic_valid(ss)
311
312
def test_constraint_ellipsis_mask():
    """A non-zero ellipsis_mask is not supported."""
    # (Original comment was garbled by a search/replace: "Unsemantic_checkered".)
    ss = create_strided_slice()
    ss.attrs["ellipsis_mask"] = 1
    assert not semantic_checker.is_operator_semantic_valid(ss)
318
319
def test_constraint_axis_masks():
    """new_axis_mask and shrink_axis_mask may not both be non-zero."""
    # Either mask alone is accepted.
    ss = create_strided_slice()
    ss.attrs["new_axis_mask"] = 2
    assert semantic_checker.is_operator_semantic_valid(ss)
    ss = create_strided_slice()
    ss.attrs["shrink_axis_mask"] = 3
    assert semantic_checker.is_operator_semantic_valid(ss)
    # Setting both at once is rejected.
    ss.attrs["new_axis_mask"] = 2
    assert not semantic_checker.is_operator_semantic_valid(ss)
331
332
def test_constraint_slice_ranges():
    """Slices where an end offset is <= the begin offset are rejected."""
    # Invalid begin values.
    ss = create_strided_slice()
    ss.inputs[1].values = [0, 7, 2, 0]
    assert not semantic_checker.is_operator_semantic_valid(ss)
    # Invalid end values.
    ss = create_strided_slice()
    ss.inputs[2].values = [0, 7, 2, 0]
    assert not semantic_checker.is_operator_semantic_valid(ss)
    # Clearing begin_mask makes the raw begin offsets apply.
    ss = create_strided_slice()
    ss.attrs["begin_mask"] = 0
    assert not semantic_checker.is_operator_semantic_valid(ss)
    # Clearing end_mask makes the raw end offsets apply.
    ss = create_strided_slice()
    ss.attrs["end_mask"] = 0
    assert not semantic_checker.is_operator_semantic_valid(ss)
347
348
def test_constraint_matching_inputs_types():
    """Binary elementwise inputs must share one data type."""
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    mul.ifm2.dtype = DataType.int8  # ifm stays at the default uint8
    assert not semantic_checker.is_operator_semantic_valid(mul)
354
355
def test_constraint_matching_signed():
    """Signed inputs require a signed output as well."""
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
    mul.ofm.dtype = DataType.uint8  # unsigned OFM for signed IFMs is invalid
    assert not semantic_checker.is_operator_semantic_valid(mul)
361
362
def test_constraint_unsigned_valid():
    """With unsigned inputs the OFM must be the same type or int32."""
    mul = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
    # Same type as the inputs (default uint8) is valid.
    assert semantic_checker.is_operator_semantic_valid(mul)
    # Other signed widths are not.
    for bad_dtype in (DataType.int8, DataType.int16):
        mul.ofm.dtype = bad_dtype
        assert not semantic_checker.is_operator_semantic_valid(mul)
    # int32 is the one allowed exception.
    mul.ofm.dtype = DataType.int32
    assert semantic_checker.is_operator_semantic_valid(mul)
375
376
def test_constraint_matching_either_shapes():
    """Binary: at least one IFM shape must match the OFM; unary: the IFM must."""
    binary_cases = [
        # (ifm_shape, ifm2_shape, ofm_shape, expected_valid)
        ([1, 4], [4, 4], [4, 4], True),
        ([4, 4], [1, 4], [4, 4], True),
        ([4, 4], [4, 4], [2, 2], False),
        ([1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16], False),
        ([1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16], False),
    ]
    for ifm_shape, ifm2_shape, ofm_shape, expected in binary_cases:
        add = testutil.create_elemwise_op(Op.Add, "op", ifm_shape, ifm2_shape, ofm_shape)
        assert bool(semantic_checker.is_operator_semantic_valid(add)) == expected

    # Unary: no second input, so the single IFM shape must match the OFM shape.
    clz = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
    assert semantic_checker.is_operator_semantic_valid(clz)
    clz = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
    assert not semantic_checker.is_operator_semantic_valid(clz)
397
398
def test_constraint_alpha_valid():
    """LeakyRelu alpha must not be negative (zero is allowed)."""
    lrelu = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
    lrelu.attrs["alpha"] = 0
    assert semantic_checker.is_operator_semantic_valid(lrelu)
    lrelu.attrs["alpha"] = -1
    assert not semantic_checker.is_operator_semantic_valid(lrelu)
406
407
def test_constraint_hardswish_dtype():
    """HardSwish needs int8/uint8 tensors and matching IFM/OFM dtypes."""
    # uint8 (the default) and int8 are both valid.
    hs = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
    assert semantic_checker.is_operator_semantic_valid(hs)
    hs = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
    assert semantic_checker.is_operator_semantic_valid(hs)

    # Wider types are rejected.
    for bad_dtype in (DataType.int16, DataType.uint16, DataType.int32):
        hs = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=bad_dtype)
        assert not semantic_checker.is_operator_semantic_valid(hs)

    # Mismatching IFM/OFM dtypes are also rejected.
    in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
    out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
    hs = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
    assert not semantic_checker.is_operator_semantic_valid(hs)
429
430
def test_constraint_keep_dims_ifm_ofm():
    """FullyConnected to a 2D OFM is only valid with keep_num_dims False."""
    fc = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
    fc.attrs["keep_num_dims"] = True
    assert not semantic_checker.is_operator_semantic_valid(fc)
    fc.attrs["keep_num_dims"] = False
    assert semantic_checker.is_operator_semantic_valid(fc)
437
438
def create_mean(input_shape, output_shape, axis, datatype, attrs):
    """Build a Mean operator for the tests below.

    axis may be an int (scalar indices tensor) or a list of ints (1D indices
    tensor). Any other type raises TypeError — previously such a value fell
    through both branches and crashed later with UnboundLocalError on
    `indices`.
    """
    ifm = Tensor(input_shape, datatype, "in")
    ifm.quantization = testutil.default_quant_params()
    ofm = Tensor(output_shape, datatype, "out")
    ofm.quantization = testutil.default_quant_params()
    if type(axis) is list:
        indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
    elif type(axis) is int:
        indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
    else:
        raise TypeError(f"axis must be an int or a list of ints, got {type(axis).__name__}")
    op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
    return op
450
451
def test_mean_dtype():
    """Mean with int8 tensors is valid; int16 tensors are rejected."""
    mean = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
    assert semantic_checker.is_operator_semantic_valid(mean)
    mean.ifm.dtype = DataType.int16
    mean.ofm.dtype = DataType.int16
    assert not semantic_checker.is_operator_semantic_valid(mean)
458
459
def test_mean_axis():
    """Mean axis must select the H and/or W dimensions of the NHWC input."""
    rejected_axes = [0, [3], [1, 3], [0, 1]]
    for axis in rejected_axes:
        mean = create_mean([1, 6, 6, 16], [1, 1, 1, 16], axis, DataType.int8, {"keep_dims": True})
        assert not semantic_checker.is_operator_semantic_valid(mean)
    accepted_axes = [[1, 2], [1], 2, [2, 1]]
    for axis in accepted_axes:
        mean = create_mean([1, 6, 6, 16], [1, 1, 1, 16], axis, DataType.int8, {"keep_dims": True})
        assert semantic_checker.is_operator_semantic_valid(mean)
478
def test_matching_in_out_quant():
    """Reshape/Squeeze/ExpandDims require identical IFM and OFM quantisation."""
    quant = testutil.default_quant_params()

    # Reshape with matching quantisation parameters is valid.
    ifm_shape = [64, 16]
    ofm_shape = [1, 4, 16, 16]
    ifm = create_const_tensor("reshape_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("reshape_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    shape_tens = create_const_tensor("shape", [1], DataType.int32, ofm_shape)
    reshape = testutil.create_op(Op.Reshape, [ifm, shape_tens], ofm, set_ifm_ofm_shapes=False)
    reshape.attrs["new_shape"] = ofm_shape
    assert semantic_checker.is_operator_semantic_valid(reshape)

    # A differing zero point is rejected.
    ofm.quantization.zero_point = 32
    assert not semantic_checker.is_operator_semantic_valid(reshape)

    # A differing scale is rejected.
    ofm.quantization.zero_point = 0
    ofm.quantization.scale_f32 = 0.9
    assert not semantic_checker.is_operator_semantic_valid(reshape)

    # Squeeze with a differing zero point is rejected.
    ifm_shape = [1, 1, 1, 1001]
    ofm_shape = [1, 1001]
    ifm = create_const_tensor("squeeze_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    ofm.quantization.zero_point = 32
    squeeze = testutil.create_op(Op.Squeeze, [ifm], ofm, set_ifm_ofm_shapes=False)
    squeeze.attrs["squeeze_dims"] = [1, 2]
    assert not semantic_checker.is_operator_semantic_valid(squeeze)

    # ExpandDims with a differing zero point is rejected.
    quant = testutil.default_quant_params()
    ifm_shape = [4, 16, 16]
    ofm_shape = [1, 4, 16, 16]
    ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
    ifm.quantization = quant
    ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
    ofm.quantization = quant.clone()
    ofm.quantization.zero_point = 32
    dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
    expand = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
    assert not semantic_checker.is_operator_semantic_valid(expand)