# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Unit tests for tflite_model_semantic
import numpy as np
from ethosu.vela.data_type import DataType
from ethosu.vela.operation import Op
from ethosu.vela.operation import Padding
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import QuantizationParameters
from ethosu.vela.tensor import Tensor
from ethosu.vela.test import testutil
from ethosu.vela.tflite_model_semantic import TFLiteSemantic
semantic_checker = TFLiteSemantic()
def test_constraint_tens_no_dynamic():
# Tensors cannot be dynamic (no shape, not a scalar)
op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [])
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_tens_defined_shape():
# Tensors cannot have None in them
op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, None, 8], [1, 8, 8, 8])
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_tens_output_scalar():
# Scalar output is not allowed at all:
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [])
op.ofm.values = 0.5
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_tens_input_scalar():
# Shapeless input is allowed if it is of a certain type:
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8])
assert semantic_checker.is_operator_semantic_valid(op)
# Invalid shapeless input due to op type:
op = testutil.create_op_with_quant_tensors(Op.Relu, [], [1, 8, 8, 8])
op.ifm.values = 0.5
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_tens_shape_size():
# Tensors cannot be > 4D
op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_tens_quant_none_check():
# Tensors must have quantization parameters
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm2_quant=None)
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_tens_quant_scale():
# Quantization scale cannot be infinite
qp = QuantizationParameters()
qp.zero_point = 0
qp.scale_f32 = np.inf
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_fc_output_2d_not_supp():
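# FullyConnected output shapes that cannot be mapped to a 2D shape are rejected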
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [7, 4, 6], [3, 2, 2, 8], weights_shape=[1, 9, 1])
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 6, 1], [3, 7, 4], weights_shape=[1, 1, 7, 1])
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 1, 4, 7], [1, 9], weights_shape=[12, 3])
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4], [9], weights_shape=[3, 2])
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_fc_output_2d_is_supp():
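# FullyConnected output shapes that can be mapped to a 2D shape are supported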
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1024], [16, 64], weights_shape=[1, 1024])
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1, 1], weights_shape=[12, 1, 1, 1])
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [3, 2, 1], weights_shape=[12, 1, 1, 1])
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1], [1, 1, 3, 2], weights_shape=[12, 1, 1, 1])
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [12, 1, 1, 1], [1, 1, 1], weights_shape=[12, 1, 1, 1])
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(
Op.FullyConnected, [12, 1, 1, 1], [1, 1, 24], weights_shape=[12, 1, 1, 1]
)
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [1, 1, 1, 1], [1, 3, 4], weights_shape=[1, 1, 1, 1])
assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_conv_pass():
# First test a simple conv passes
op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 1, 1, 1], [1, 1, 1, 1], weights_shape=[1, 1, 1, 1])
op.attrs = {"stride_w": 1, "stride_h": 1}
assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_stride_type():
# Stride width and height must be integer types
op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 1.5, "stride_h": "1"}
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_dilation_type():
# Dilation width and height must be integer types
op = testutil.create_op_with_quant_tensors(Op.Conv2DBias, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 1, "stride_h": 1, "dilation_w_factor": 1.5, "dilation_h_factor": "1"}
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_quant_scale_inf():
# The ratio of IFM scale to OFM scale must not overflow to infinity
op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
op.ifm.quantization.scale_f32 = np.float32(1e9)
op.ofm.quantization.scale_f32 = np.float32(1e-35)
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_ofm_scale_too_small():
# Tests handling of OFM scale < 1e-38
shp = [1, 10, 20, 16]
op = testutil.create_elemwise_op(
Op.Mul,
"mul",
shp,
shp,
shp,
ofm_quant=testutil.default_quant_params(),
)
assert semantic_checker.is_operator_semantic_valid(op)
op.ofm.quantization.scale_f32 = 1e-43
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_matching_in_out_types():
# Valid
op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
assert semantic_checker.is_operator_semantic_valid(op)
# Invalid: data types for ifm and ofm must match (the ofm keeps the default uint8)
op.ifm.dtype = DataType.int8
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_filter_type():
# Filter width/height must be integers
op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_matching_shapes():
# Softmax requires the ifm and ofm shapes to match
op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 2, 2, 4])
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_beta_value_range():
# beta must not be negative
op = testutil.create_op_with_quant_tensors(Op.Softmax, [1, 1, 1, 8], [1, 1, 1, 8])
op.attrs["beta"] = -1.0
assert not semantic_checker.is_operator_semantic_valid(op)
op.attrs["beta"] = 0.0
assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_splitv_inferred():
# SplitV requires a maximum of one inferred shape (-1)
qp = testutil.default_quant_params()
op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
op.add_input_tensor(sizes)
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
op.add_input_tensor(sizes)
assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_concat_pass():
# A working concat
op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
ifm2.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm2)
op.attrs["axis"] = 3
assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_axis_exists():
# Missing axis attribute
op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
ifm2.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm2)
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_axis_valid():
# Invalid axis attribute
op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
ifm2 = Tensor([1, 1, 1, 4], DataType.uint8, "in2")
ifm2.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm2)
op.attrs["axis"] = 7
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_matching_dimensionality():
# Mismatching dimensionality: 4D+2D=4D
op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
ifm2 = Tensor([1, 4], DataType.uint8, "in2")
ifm2.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm2)
op.attrs["axis"] = 3
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_valid_dimensions():
# Mismatching dimension value:
# ifm2 has w and h set to 2, which is not the concat axis and does not match ifm1 or the ofm
op = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
ifm2 = Tensor([1, 2, 2, 4], DataType.uint8, "in2")
ifm2.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm2)
op.attrs["axis"] = 3
assert not semantic_checker.is_operator_semantic_valid(op)
def create_strided_slice_op(in_shape, out_shape, start_offsets, end_offsets):
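# Helper: builds a StridedSlice op with constant begin/end/strides inputs and all masks cleared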
qp = testutil.default_quant_params()
in0 = Tensor(in_shape, DataType.uint8, "in")
in0.quantization = qp
in1 = create_const_tensor("begin", [len(start_offsets)], DataType.uint8, start_offsets, quantization=qp)
in2 = create_const_tensor("end", [len(end_offsets)], DataType.uint8, end_offsets, quantization=qp)
in3 = create_const_tensor("strides", [len(end_offsets)], DataType.uint8, len(end_offsets) * [1], quantization=qp)
out = Tensor(out_shape, DataType.uint8, "out")
out.quantization = qp
attrs = {"ellipsis_mask": 0, "new_axis_mask": 0, "shrink_axis_mask": 0, "begin_mask": 0, "end_mask": 0}
return testutil.create_op(Op.StridedSlice, [in0, in1, in2, in3], out, attrs=attrs)
def create_pad_op(
in_shape,
out_shape,
padding,
in_dtype=DataType.int8,
out_dtype=DataType.int8,
pad_dtype=DataType.int32,
):
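# Helper: builds a Pad op with a constant paddings tensor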
qp = testutil.default_quant_params()
in0 = Tensor(in_shape, in_dtype, "in")
in0.quantization = qp
pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
out = Tensor(out_shape, out_dtype, "out")
out.quantization = qp.clone()
op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
return op
def test_constraint_pad_input_count():
# Pad requires exactly two input tensors; adding a third makes the op invalid
op = create_pad_op(
in_shape=[1, 1, 1, 1],
out_shape=[1, 3, 3, 1],
padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
)
assert semantic_checker.is_operator_semantic_valid(op)
op.add_input_tensor(op.inputs[0].clone())
assert not semantic_checker.is_operator_semantic_valid(op)
def create_strided_slice():
# Creates a valid strided slice operator with some valid inputs/outputs
op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
op.attrs["begin_mask"] = 1
op.attrs["end_mask"] = 9
assert semantic_checker.is_operator_semantic_valid(op)
return op
def test_constraint_stridedslice_input_count():
# Wrong number of input tensors
op = create_strided_slice()
op.add_input_tensor(op.inputs[0].clone())
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_stridedslice_inputs_const():
# begin, end, stride values must not be None
op = create_strided_slice()
op.inputs[1].values = None
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_strided_slice()
op.inputs[2].values = None
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_strided_slice()
op.inputs[3].values = None
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_ellipsis_mask():
# A non-zero ellipsis_mask is not supported
op = create_strided_slice()
op.attrs["ellipsis_mask"] = 1
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_axis_masks():
op = create_strided_slice()
# Setting one of new_axis_mask/shrink_axis_mask to non-zero is ok
op.attrs["new_axis_mask"] = 2
assert semantic_checker.is_operator_semantic_valid(op)
op = create_strided_slice()
op.attrs["shrink_axis_mask"] = 3
assert semantic_checker.is_operator_semantic_valid(op)
# But setting both to non-zero is not supported
op.attrs["new_axis_mask"] = 2
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_slice_ranges():
# Examples where end offset <= begin offset
op = create_strided_slice()
op.inputs[1].values = [0, 7, 2, 0]
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_strided_slice()
op.inputs[2].values = [0, 7, 2, 0]
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_strided_slice()
op.attrs["begin_mask"] = 0
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_strided_slice()
op.attrs["end_mask"] = 0
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_matching_inputs_types():
# input data types must match (default is uint8)
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
op.ifm2.dtype = DataType.int8
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_matching_signed():
# signed inputs require output to also be signed
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
op.ofm.dtype = DataType.uint8
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_unsigned_valid():
# unsigned inputs require output to be either:
op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [1, 8, 8, 8], [1, 8, 8, 8])
# the same (default uint8)
assert semantic_checker.is_operator_semantic_valid(op)
op.ofm.dtype = DataType.int8
assert not semantic_checker.is_operator_semantic_valid(op)
op.ofm.dtype = DataType.int16
assert not semantic_checker.is_operator_semantic_valid(op)
# or int32
op.ofm.dtype = DataType.int32
assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_matching_either_shapes():
# BINARY CASE
# At least one ifm shape must match ofm's shape
op = testutil.create_elemwise_op(Op.Add, "op", [1, 4], [4, 4], [4, 4])
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [1, 4], [4, 4])
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_elemwise_op(Op.Add, "op", [4, 4], [4, 4], [2, 2])
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_elemwise_op(Op.Add, "op", [1, 4, 1, 16], [1, 1, 4, 1], [1, 4, 4, 16])
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_elemwise_op(Op.Add, "op", [1, 1, 4, 1], [1, 4, 1, 16], [1, 4, 4, 16])
assert not semantic_checker.is_operator_semantic_valid(op)
# UNARY CASE
# No second input so this is treated the same as requiring ifm shape to match ofm shape
op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
assert semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_elemwise_op(Op.CLZ, "op", [4, 4], None, [2, 2], datatype=DataType.int32)
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_alpha_valid():
# Alpha can only be negative for int8 and uint8
op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2], DataType.int16)
op.attrs["alpha"] = 0
assert semantic_checker.is_operator_semantic_valid(op)
op.attrs["alpha"] = -1
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2], DataType.int8)
op.attrs["alpha"] = 0
assert semantic_checker.is_operator_semantic_valid(op)
op.attrs["alpha"] = -1
assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_hardswish_dtype():
# HardSwish operator dtype should be int8 or uint8, and input dtype must match output
# UINT8
op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
assert semantic_checker.is_operator_semantic_valid(op)
# INT8
op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
assert semantic_checker.is_operator_semantic_valid(op)
# Invalid
op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
assert not semantic_checker.is_operator_semantic_valid(op)
in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
assert not semantic_checker.is_operator_semantic_valid(op)
def test_constraint_keep_dims_ifm_ofm():
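# With a 4D ifm and a 2D ofm, setting keep_num_dims is rejected; clearing it is accepted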
op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
op.attrs["keep_num_dims"] = True
assert not semantic_checker.is_operator_semantic_valid(op)
op.attrs["keep_num_dims"] = False
assert semantic_checker.is_operator_semantic_valid(op)
def create_mean(input_shape, output_shape, axis, datatype, attrs):
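# Helper: builds a Mean op with a constant axis (indices) tensor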
ifm = Tensor(input_shape, datatype, "in")
ifm.quantization = testutil.default_quant_params()
ofm = Tensor(output_shape, datatype, "out")
ofm.quantization = testutil.default_quant_params()
if type(axis) is list:
indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
elif type(axis) is int:
indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
return op
def test_mean_dtype():
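# Mean with int8 tensors is valid; switching the ifm/ofm to int16 is rejected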
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert semantic_checker.is_operator_semantic_valid(op)
op.ifm.dtype = DataType.int16
op.ofm.dtype = DataType.int16
assert not semantic_checker.is_operator_semantic_valid(op)
def test_mean_axis():
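# Only reduction over the height and/or width axes (axes 1 and 2 of the NHWC input) is accepted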
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 0, DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [3], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 3], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [0, 1], DataType.int8, {"keep_dims": True})
assert not semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [1], DataType.int8, {"keep_dims": True})
assert semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], 2, DataType.int8, {"keep_dims": True})
assert semantic_checker.is_operator_semantic_valid(op)
op = create_mean([1, 6, 6, 16], [1, 1, 1, 16], [2, 1], DataType.int8, {"keep_dims": True})
assert semantic_checker.is_operator_semantic_valid(op)
def test_matching_in_out_quant():
# Quantisation parameters of ifm and ofm must match.
quant = testutil.default_quant_params()
# create reshape op
ifm_shape = [64, 16]
ofm_shape = [1, 4, 16, 16]
ifm = create_const_tensor("reshape_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
ifm.quantization = quant
ofm = create_const_tensor("reshape_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
ofm.quantization = quant.clone()
shape_tens = create_const_tensor("shape", [1], DataType.int32, ofm_shape)
op = testutil.create_op(Op.Reshape, [ifm, shape_tens], ofm, set_ifm_ofm_shapes=False)
op.attrs["new_shape"] = ofm_shape
# Matching quantisation parameters
assert semantic_checker.is_operator_semantic_valid(op)
# Different zp
ofm.quantization.zero_point = 32
assert not semantic_checker.is_operator_semantic_valid(op)
# Different scale
ofm.quantization.zero_point = 0
ofm.quantization.scale_f32 = 0.9
assert not semantic_checker.is_operator_semantic_valid(op)
# Squeeze op diff quant
# create squeeze op
ifm_shape = [1, 1, 1, 1001]
ofm_shape = [1, 1001]
ifm = create_const_tensor("squeeze_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
ifm.quantization = quant
ofm = create_const_tensor("squeeze_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
ofm.quantization = quant.clone()
ofm.quantization.zero_point = 32
op = testutil.create_op(Op.Squeeze, [ifm], ofm, set_ifm_ofm_shapes=False)
op.attrs["squeeze_dims"] = [1, 2]
assert not semantic_checker.is_operator_semantic_valid(op)
# ExpandDims diff quant
quant = testutil.default_quant_params()
# create expand_dims op
ifm_shape = [4, 16, 16]
ofm_shape = [1, 4, 16, 16]
ifm = create_const_tensor("expand_dims_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
ifm.quantization = quant
ofm = create_const_tensor("expand_dims_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
ofm.quantization = quant.clone()
ofm.quantization.zero_point = 32
dim = create_const_tensor("expand_dims_dim", [], DataType.uint8, 0)
op = testutil.create_op(Op.ExpandDims, [ifm, dim], ofm, set_ifm_ofm_shapes=False)
assert not semantic_checker.is_operator_semantic_valid(op)