blob: 1c258de8fdb8c98fd0286093e5ada6f4bafccd52 [file] [log] [blame]
Johan Alfven12e48112023-01-31 10:26:26 +01001# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020017# Description:
18# The TFLiteSemantic class which is a collection of TensorFlow lite model semantic checks.
19from collections import defaultdict
20
21import numpy as np
22
23from .data_type import BaseType
24from .data_type import DataType
25from .numeric_util import is_integer
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020026from .operation import Op
27from .supported_operators_util import docstring_format_args
28from .supported_operators_util import list_formatter
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020029from .tensor import check_quantized_tens_scaling_equal
Johan Alfven3ac03be2023-03-01 09:53:35 +010030from .tensor import shape_num_elements
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020031from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
32from .tflite_mapping import optype_to_builtintype
33
34
def _optype_formatter(op_list):
    """Return a formatted list of the external (TFLite builtin) names of the given internal op types.

    Op types that map to BUILTIN_OPERATOR_UNKNOWN have no external name and are left out.
    """
    builtin_names = (
        name for name in map(optype_to_builtintype, op_list) if name is not BUILTIN_OPERATOR_UNKNOWN
    )
    return list_formatter(builtin_names)
41
42
43class TFLiteSemantic:
# Categorised sets of operators, used to register constraints per op type
convolution_ops = {Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D}
depthwise_convolution_ops = {Op.DepthwiseConv2DBias}
transpose_convolution_ops = {Op.Conv2DBackpropInput}
convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops

# Pooling (ReduceSum is grouped with the pooling ops)
max_pooling_ops = Op.op_set(Op.is_maxpool_op)
avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
pooling_ops = {Op.ReduceSum} | max_pooling_ops | avg_pooling_ops

# Element-wise
unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
binary_elem_wise_min_max_ops = {Op.Minimum, Op.Maximum}
binary_elem_wise_shift_ops = {Op.SHL, Op.SHR}
binary_elem_wise_add_mul_sub = {Op.Add, Op.Mul, Op.Sub}
binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops | {Op.SquaredDifference}

# Ops that may legitimately take scalar (shapeless) inputs
shapeless_input_ops = binary_elem_wise_main_ops | {
    Op.Split,
    Op.SplitV,
    Op.Mean,
    Op.ExpandDims,
    Op.Quantize,
    Op.ArgMax,
}

# Ops that only reshape dimensions
reshape_ops = {Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims}
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020091
def __init__(self):
    """Build the generic and per-op-type semantic constraint tables.

    ``generic_constraints`` lists the checks applied to every operator. ``specific_constraints`` maps an
    op type to the extra checks for that type. Checks are evaluated in list order and evaluation stops at
    the first failure, so the order of registration matters.
    """
    sem = TFLiteSemantic

    # Generic constraints (order matters)
    self.generic_constraints = [
        sem.constraint_attributes_specified,
        sem.constraint_tens_no_dynamic,
        sem.constraint_tens_defined_shape,
        sem.constraint_tens_output_scalar,
        sem.constraint_tens_input_scalar,
        sem.constraint_tens_shape_size,
        sem.constraint_tens_quant_none_check,
        sem.constraint_tens_quant_scale,
        sem.constraint_quant_scale_inf,
        sem.constraint_none_const_tensors,
    ]

    # Specific constraints (order matters)
    self.specific_constraints = defaultdict(list)

    def register(op_types, *constraints):
        # Append the given checks, in order, to each op type's list
        for op_type in op_types:
            self.specific_constraints[op_type].extend(constraints)

    # Conv-like checks: only Conv has groups, and Transpose Conv has no dilation
    register(sem.convolution_like_ops, sem.constraint_stride_type)
    register(sem.convolution_ops, sem.constraint_conv_groups_ifm_depth, sem.constraint_conv_groups_num_filters)
    register(sem.convolution_like_ops - sem.transpose_convolution_ops, sem.constraint_dilation_type)

    # Pooling checks, then AVG- and MAX-pooling specific checks
    register(sem.pooling_ops, sem.constraint_stride_type)
    register(sem.avg_pooling_ops, sem.constraint_matching_in_out_types, sem.constraint_filter_type)
    register(sem.max_pooling_ops, sem.constraint_matching_in_out_types, sem.constraint_filter_type)

    # Concat checks
    register(
        (Op.Concat, Op.ConcatTFLite),
        sem.constraint_axis_exists,
        sem.constraint_axis_valid,
        sem.constraint_matching_dimensionality,
        sem.constraint_valid_dimensions,
        sem.constraint_valid_dimensions_axis,
    )

    # Element-wise checks, then unary, Min/Max and Add/Mul/Sub specific checks
    register(sem.elem_wise_main_ops, sem.constraint_matching_either_shapes)
    register(sem.unary_elem_wise_main_ops, sem.constraint_matching_in_out_types)
    register(sem.binary_elem_wise_min_max_ops, sem.constraint_matching_in_out_types)
    register(
        sem.binary_elem_wise_add_mul_sub,
        sem.constraint_matching_inputs_types,
        sem.constraint_matching_signed,
        sem.constraint_unsigned_valid,
    )

    # Ops reshaping dimensions: Reshape, Squeeze and ExpandDims
    register(sem.reshape_ops, sem.constraint_matching_in_out_quant, sem.constraint_matching_in_out_elements)

    # Single-op-type checks
    register(
        (Op.Softmax,),
        sem.constraint_matching_shapes,
        sem.constraint_matching_in_out_types,
        sem.constraint_beta_value_range,
    )
    register((Op.Split,), sem.constraint_split_axis, sem.constraint_split_num_splits)
    register((Op.SplitV,), sem.constraint_splitv_inferred)
    register(
        (Op.StridedSlice,),
        sem.constraint_stridedslice_input_count,
        sem.constraint_stridedslice_inputs_const,
        sem.constraint_ellipsis_mask,
        sem.constraint_axis_masks,
        sem.constraint_slice_ranges,
    )
    register((Op.FullyConnected,), sem.constraint_fc_output_2d, sem.constraint_keep_dim_ifm_ofm)
    register(
        (Op.Pad,),
        sem.constraint_pad_input_count,
        sem.constraint_pad_constant,
        sem.constraint_pad_output_shape,
    )
    register((Op.HardSwish,), sem.constraint_input_8bit, sem.constraint_matching_in_out_types)
    register((Op.Mean,), sem.constraint_mean_input_dims, sem.constraint_mean_axis)
    register((Op.ArgMax,), sem.constraint_input_8bit, sem.constraint_argmax_output)
    register(
        (Op.UnidirectionalSequenceLstm,),
        sem.constraint_input_signed,
        sem.constraint_matching_in_out_types,
        sem.constraint_lstm_dimensions,
        sem.constraint_lstm_inputs,
        sem.constraint_lstm_intermediates,
        sem.constraint_lstm_variables,
    )
    register((Op.Exp,), sem.constraint_input_signed)
    register(
        (Op.Transpose,),
        sem.constraint_transpose_permutation_size,
        sem.constraint_transpose_permutation_values,
    )
215
def is_operator_semantic_valid(self, op):
    """Return True if op passes every applicable semantic check, otherwise False.

    Placeholder, SubgraphInput and Const ops are always accepted. For any other op, the generic
    constraints (minus those excluded for the op's type) run first, followed by the op type's specific
    constraints. On the first failing constraint, a warning is printed naming the op, the constraint's
    docstring and any extra detail it returned, and False is returned.
    """
    ext_type = optype_to_builtintype(op.type)

    if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
        return True

    # Build the exclude table and look up this op type once, instead of rebuilding it per constraint
    excluded = self.get_generic_constraint_exclude_list().get(op.type, [])
    constraints = [constraint for constraint in self.generic_constraints if constraint not in excluded]
    constraints += self.specific_constraints[op.type]

    for constraint in constraints:
        valid, extra = constraint(op)
        if not valid:
            print(f"Warning: Unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead")
            print(f" - {constraint.__doc__}")
            if extra:
                print(f" {extra}")
            return False

    return True
242
@staticmethod
def get_generic_constraint_exclude_list():
    """Return a dict mapping op type to the generic constraints that are skipped for that op type.

    Not every generic constraint applies to every operator. A fresh dict (with fresh lists) is built
    on every call.
    """
    sem = TFLiteSemantic
    quant_none_check = sem.constraint_tens_quant_none_check
    return {
        Op.Shape: [quant_none_check],
        Op.Quantize: [sem.constraint_tens_no_dynamic, sem.constraint_tens_output_scalar],
        Op.ArgMax: [quant_none_check],
        Op.Transpose: [quant_none_check],
        Op.MirrorPad: [quant_none_check],
    }
266
267 @staticmethod
erik.andersson@arm.com3bbbed62021-12-20 14:14:16 +0100268 def constraint_none_const_tensors(op):
269 "Constant tensors should not have NoneType-values"
270 valid = True
271 extra = ""
272 for tens in filter(None, op.inputs):
273 if len(tens.ops) > 0 and tens.ops[0].type == Op.Const and tens.values is None:
274 valid = False
275 extra = str(tens.name)
276 return valid, f"Unexpected None value for constant tensor: {extra}"
277
278 @staticmethod
Tim Hall2180a172023-03-10 18:11:34 +0000279 def constraint_attributes_specified(op):
280 "All required operator attributes must be specified"
281 # operators that have been created internally (i.e. not created as part of reading an input network) may not
282 # have the read error attribute
283 attribute_read_error = op.attrs.get("attribute_read_error", [])
284 valid = len(attribute_read_error) == 0
285 extra = ", ".join(attribute_read_error)
286 return valid, f"Op has missing attributes: {extra}"
287
288 @staticmethod
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200289 def constraint_tens_no_dynamic(op):
290 "Input(s) and Output tensors must not be dynamic"
291 valid = True
292 extra = []
293 tensors = [tens for tens in op.inputs + op.outputs if tens]
294 for tens in tensors:
295 if (tens.shape == []) and (tens.values is None):
296 valid = False
297 extra.append(tens.name)
298 extra = ", ".join(extra)
299 return valid, f"Op has dynamic tensor(s): {extra}"
300
301 @staticmethod
302 def constraint_tens_defined_shape(op):
303 "Input(s) and Output tensors must have a defined shape"
304 valid = True
305 extra = []
306 tensors = [tens for tens in op.inputs + op.outputs if tens]
307 for tens in tensors:
308 if not tens.has_fully_defined_shape():
309 valid = False
310 extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
311 return valid, ", ".join(extra)
312
313 @staticmethod
314 def constraint_tens_output_scalar(op):
315 "Output tensors cannot be scalar"
316 ofm = op.ofm
317 valid = ofm.shape != []
318 return valid, f"Output Tensor '{ofm.name}' is scalar"
319
@classmethod
@docstring_format_args([_optype_formatter(shapeless_input_ops)])
def constraint_tens_input_scalar(cls, op):
    "Scalar Input tensors are only valid for op type: {}"
    # Inputs with an empty shape are scalars; they are rejected unless the op type accepts them
    scalar_names = [
        tens.name for tens in op.inputs if tens and tens.shape == [] and op.type not in cls.shapeless_input_ops
    ]
    return not scalar_names, f"Op has scalar input tensor(s): {', '.join(scalar_names)}"
333
334 @staticmethod
335 def constraint_tens_shape_size(op):
336 "Input(s) and Output tensors must not be greater than 4D"
337 valid = True
338 extra = []
339 tensors = [tens for tens in op.inputs + op.outputs if tens]
340 for tens in tensors:
341 if len(tens.shape) > 4:
342 valid = False
343 extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
344 return valid, ", ".join(extra)
345
346 @staticmethod
347 def constraint_tens_quant_none_check(op):
348 "Input(s), Output and Weight tensors must have quantization parameters"
349 valid = True
350 extra = []
351 tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
352 for tens in tensors:
353 if tens.quantization is None:
354 valid = False
355 extra.append(tens.name)
356 extra = ", ".join(extra)
357 return valid, f"Op has tensors with missing quantization parameters: {extra}"
358
359 @staticmethod
360 def constraint_tens_quant_scale(op):
361 "Input(s), Output and Weight tensors with quantization scales must be finite"
362 valid = True
363 extra = []
364 tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
365 for tens in tensors:
Fredrik Svedberg11563172022-07-06 14:54:12 +0200366 if (
367 tens.quantization
368 and tens.quantization.scale_f32 is not None
369 and np.isinf(tens.quantization.scale_f32).any()
370 ):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200371 valid = False
372 extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
373 return valid, ", ".join(extra)
374
375 @staticmethod
376 def constraint_fc_output_2d(op):
Ayaan Masooda2ec5aa2022-04-21 14:28:03 +0100377 """The output tensor(s) must have 2D shape"""
378 valid = op.ifm.get_shape_as_2d(op.weights.shape[-2]) is not None
379 extra = f"Op has non-2D output tensor '{op.ofm.name}'" if not valid else ""
380
381 return valid, extra
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200382
@staticmethod
def constraint_stride_type(op):
    "Stride values for both width and height must be integer types"
    stride_w, stride_h = op.get_kernel_stride()
    both_integer = is_integer(stride_w) and is_integer(stride_h)
    return both_integer, f"Op has stride WxH as: {stride_w!r}x{stride_h!r}"
389
390 @staticmethod
Tim Hall9cf63a32023-06-27 12:07:49 +0100391 def constraint_conv_groups_ifm_depth(op):
392 """IFM depth must be a whole multiple of the filter kernel depth"""
393 ifm_depth = op.ifm.shape[-1] # nhwc
394 kernel_ic = op.weights.shape[-2] # hwio
395 num_conv_groups = ifm_depth // kernel_ic
396
397 if ifm_depth % kernel_ic == 0:
398 op.attrs["num_conv_groups"] = num_conv_groups
399 valid = True
400 else:
401 valid = False
402
403 return valid, f"IFM depth = {ifm_depth} and filter kernel depth = {kernel_ic}"
404
405 @staticmethod
406 def constraint_conv_groups_num_filters(op):
407 """Number of filter kernels must be equally divisible by the number of convolution groups"""
408 ifm_depth = op.ifm.shape[-1] # nhwc
409 kernel_ic = op.weights.shape[-2] # hwio
410 kernel_oc = op.weights.shape[-1] # hwio
411 num_conv_groups = ifm_depth // kernel_ic
412
413 if kernel_oc % num_conv_groups == 0:
414 valid = True
415 else:
416 valid = False
417
418 return valid, f"Filter kernels = {kernel_oc} and convolution groups = {num_conv_groups}"
419
@staticmethod
def constraint_dilation_type(op):
    "Dilation factor values for both width and height must be integer types"
    dilation_w, dilation_h = op.get_kernel_dilation()
    both_integer = is_integer(dilation_w) and is_integer(dilation_h)
    return both_integer, f"Op has dilation factor WxH as: {dilation_w!r}x{dilation_h!r}"
426
427 @staticmethod
428 def constraint_quant_scale_inf(op):
429 "Input and Output tensors must have quantization scales that fit within float32 precision"
430 if op.ofm is not None and op.ofm.is_quantized():
431 ofm_scale = op.ofm.quantization.scale_f32
Dwight Lidman4caf29d2021-10-08 14:26:54 +0200432 if np.any(ofm_scale < np.finfo(np.float32).tiny):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200433 return (
434 False,
435 f"The quantization scale of the output tensor is {ofm_scale}, "
436 + f"minimum supported is: {np.finfo(np.float32).tiny}",
437 )
438 if op.ifm is not None and op.ifm.is_quantized():
439 ifm_scale = op.ifm.quantization.scale_f32
Dwight Lidman4caf29d2021-10-08 14:26:54 +0200440 if np.any(np.isinf(ifm_scale / ofm_scale)):
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200441 return (
442 False,
443 f"IFM scale divided by OFM scale is infinite, ifm_scale={ifm_scale} ofm_scale={ofm_scale}",
444 )
445 return True, "Op's quantization is ok"
446
447 @staticmethod
448 def constraint_matching_in_out_types(op):
449 "IFM and OFM data types must match"
450 ifm_dtype = op.ifm.dtype
451 ofm_dtype = op.ofm.dtype
452 valid = ifm_dtype == ofm_dtype
453 return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
454
455 @staticmethod
456 def constraint_beta_value_range(op):
457 "Beta value needs to be positive"
458 beta = op.attrs.get("beta", 1.0)
459 valid = beta >= 0
460 return valid, f"Op has beta={beta}"
461
@staticmethod
def constraint_filter_type(op):
    "Kernel filter values for both width and height must be integer types"
    kernel = op.kernel
    kernel_w, kernel_h = kernel.width, kernel.height
    both_integer = is_integer(kernel_w) and is_integer(kernel_h)
    return both_integer, f"Op has kernel filter WxH as: {kernel_w!r}x{kernel_h!r}"
469
470 @staticmethod
471 def constraint_matching_shapes(op):
472 "IFM and OFM shapes must match"
473 ifm_shape = op.ifm.shape
474 ofm_shape = op.ofm.shape
475 valid = ifm_shape == ofm_shape
476 return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"
477
478 @staticmethod
Johan Alfven12e48112023-01-31 10:26:26 +0100479 def constraint_split_axis(op):
480 "Axis value must be in the range [-RANK(IFM) to +RANK(IFM))"
481 axis_tens = op.inputs[0]
482 input_tens = op.inputs[1]
483 dims = len(input_tens.shape)
Tim Hall762d3ac2023-07-06 11:42:02 +0100484 # handle axis being a scalar or 1-D array
William Isaksson75d34022023-08-10 12:22:44 +0000485 if axis_tens.values.ndim == 0:
486 axis = int(axis_tens.values)
487 else:
488 axis = int(axis_tens.values[0])
Johan Alfven12e48112023-01-31 10:26:26 +0100489 axis += dims if axis < 0 else 0
490 valid = 0 <= axis < dims
491 return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}"
492
493 @staticmethod
494 def constraint_split_num_splits(op):
495 "Axis must be divisible by number of splits"
496 num_splits = op.attrs.get("num_splits")
497 axis_tens = op.inputs[0]
498 input_tens = op.inputs[1]
499 dims = len(input_tens.shape)
Tim Hall762d3ac2023-07-06 11:42:02 +0100500 # handle axis being a scalar or 1-D array
William Isaksson75d34022023-08-10 12:22:44 +0000501 if axis_tens.values.ndim == 0:
502 axis = int(axis_tens.values)
503 else:
504 axis = int(axis_tens.values[0])
Johan Alfven12e48112023-01-31 10:26:26 +0100505 axis += dims if axis < 0 else 0
506 valid = input_tens.shape[axis] % num_splits == 0
507 return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}"
508
509 @staticmethod
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200510 def constraint_splitv_inferred(op):
511 "Only one size is allowed to be inferred"
512 sizes = op.inputs[1].values
513 valid = np.count_nonzero(sizes == -1) <= 1
514 return valid, f"Op has multiple inferred sizes (-1): {sizes}"
515
516 @staticmethod
517 def constraint_axis_exists(op):
518 "Axis attribute must exist"
519 axis = op.attrs.get("axis")
520 valid = axis is not None
521 return valid, f"Op has axis={axis}"
522
523 @staticmethod
524 def constraint_axis_valid(op):
525 "Axis attribute must be in the range [0, <ofm_dimensions>)"
526 dims = len(op.ofm.shape)
527 axis = op.attrs["axis"]
528 axis += dims if axis < 0 else 0
529 valid = 0 <= axis < dims
530 return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"
531
532 @staticmethod
533 def constraint_matching_dimensionality(op):
534 "All Input dimensionalities must match OFM dimensionality"
535 valid = True
536 extra = []
537 ofm_dim = len(op.ofm.shape)
538 tensors = [tens for tens in op.inputs if tens]
539 for tens in tensors:
540 dim = len(tens.shape)
541 if dim != ofm_dim:
542 valid = False
543 extra.append(f"Tensor '{tens.name}' has dimension: {dim}")
544 extra = ", ".join(extra)
545 return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"
546
547 @staticmethod
548 def constraint_valid_dimensions(op):
549 "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
550 valid = True
551 extra = []
552 ofm_shape = op.ofm.shape
553 ofm_dim = len(ofm_shape)
554 axis = op.attrs["axis"]
555 axis += ofm_dim if axis < 0 else 0
556 tensors = [tens for tens in op.inputs if tens]
557 for tens in tensors:
558 if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
559 valid = False
560 extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
561 extra = ", ".join(extra)
562 return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
563
564 @staticmethod
Johan Alfvénb3932512022-09-12 17:44:25 +0200565 def constraint_valid_dimensions_axis(op):
566 """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute"""
567 valid = True
568 extra = []
569 ofm_shape = op.ofm.shape
570 ofm_dim = len(ofm_shape)
571 axis = op.attrs["axis"]
572 axis += ofm_dim if axis < 0 else 0
573
574 sum_ifm_axis = 0
575 tensors = [tens for tens in op.inputs if tens]
576 for tens in tensors:
577 sum_ifm_axis += tens.shape[axis]
578 extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
579
580 valid = sum_ifm_axis == ofm_shape[axis]
581 extra = ", ".join(extra)
582 return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
583
584 @staticmethod
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200585 def constraint_stridedslice_input_count(op):
586 "Exactly 4 Input tensors are required"
587 inputs = len(op.inputs)
588 valid = inputs == 4
589 return valid, f"Op has {inputs} inputs"
590
591 @staticmethod
592 def constraint_pad_input_count(op):
593 "Number of input tensors must be exactly 2"
594 inputs = len(op.inputs)
595 valid = inputs == 2
596 return valid, f"Op has {inputs} inputs"
597
598 @staticmethod
599 def constraint_pad_constant(op):
600 "The padding tensor must be constant"
601 pad_tensor = op.inputs[1].values
602 valid = pad_tensor is not None
603 return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"
604
605 @staticmethod
Johan Gunnarsson81b765d2023-08-04 17:16:29 +0200606 def constraint_pad_output_shape(op):
607 "Shape of output tensor must equal to size of input tensor plus padding"
608 input_shape = op.inputs[0].shape
609 expected_output_shape = op.outputs[0].shape
610 pad_tensor = op.inputs[1].values
611 actual_output_shape = input_shape + pad_tensor.T[0] + pad_tensor.T[1]
612 valid = np.array_equal(actual_output_shape, expected_output_shape)
613 return valid, f"Op has wrong output tensor shape: {expected_output_shape}, has shape: {actual_output_shape}"
614
615 @staticmethod
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200616 def constraint_stridedslice_inputs_const(op):
617 "Begin, End and Stride Input tensors must be constant"
618 valid = True
619 extra = []
620 _, begin, end, strides = op.inputs
621 if begin.values is None:
622 valid = False
623 extra.append(f"Begin tensor '{begin.name}'")
624 if end.values is None:
625 valid = False
626 extra.append(f"End tensor '{end.name}'")
627 if strides.values is None:
628 valid = False
629 extra.append(f"Stride tensor '{strides.name}'")
630 extra = ", ".join(extra)
631 return valid, f"Op has non-constant tensors: {extra}"
632
633 @staticmethod
634 def constraint_ellipsis_mask(op):
635 "ellipsis_mask must be 0"
636 ellipsis = op.attrs["ellipsis_mask"]
637 valid = ellipsis == 0
638 return valid, f"Op has ellipsis mask as: {ellipsis}"
639
640 @staticmethod
641 def constraint_axis_masks(op):
642 "new_axis_mask and shrink_axis_mask cannot both be set"
643 new_axis = op.attrs["new_axis_mask"]
644 shrink_axis = op.attrs["shrink_axis_mask"]
645 valid = (new_axis == 0) or (shrink_axis == 0)
646 return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"
647
Tim Halld0e41cf2023-02-14 14:54:18 +0000648 def _get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
649 # For strided slice operator: get start or end offsets
650 # input_shape: List[int], offset_tens: Tensor, offset_mask: int, is_begin: bool = True
651 offsets = len(input_shape) * [0] if is_begin else input_shape[:]
652 for idx in range(len(input_shape)):
653 # If the i:th bit in the mask is not set then the value in offset_tens[i] should be used, otherwise it
654 # should be ignored
655 if (offset_mask & (1 << idx)) == 0:
656 offsets[idx] = offset_tens.values[idx]
657 if offsets[idx] < 0:
658 # Convert negative indexing to positive ones
659 offsets[idx] += input_shape[idx]
660 return offsets
661
@staticmethod
def constraint_slice_ranges(op):
    "Slice 'end' values must be greater than 'begin' values"
    ifm, begin, end, _ = op.inputs
    shrink_axis_mask = op.attrs["shrink_axis_mask"]
    # Resolve the begin/end offsets, honouring the begin/end masks and negative indexing
    offset_begin = TFLiteSemantic._get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
    offset_end = TFLiteSemantic._get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)

    valid = True
    for axis in range(len(ifm.shape)):
        if shrink_axis_mask & (1 << axis):
            # Shrunk axis: ignore the supplied end and take exactly one element, so that later arithmetic
            # such as (end - begin) in the graph optimiser yields the correct value
            offset_end[axis] = offset_begin[axis] + 1
        elif offset_end[axis] <= offset_begin[axis]:
            # "end - begin" would give a zero or negative extent
            valid = False

    # Side effect: store the resolved offsets on the op's attributes
    op.attrs["offset_begin"] = offset_begin
    op.attrs["offset_end"] = offset_end
    return valid, f"Op has begin_values={begin.values} and end_values={end.values}"
685
686 @staticmethod
687 def constraint_matching_inputs_types(op):
688 "Both Input data types must match"
689 ifm_dtype = op.ifm.dtype
690 ifm2_dtype = op.ifm2.dtype
691 valid = ifm_dtype == ifm2_dtype
692 return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"
693
@staticmethod
def constraint_matching_signed(op):
    "For IFM that are signed, OFM must also be signed"
    ifm_dtype = op.ifm.dtype
    ofm_dtype = op.ofm.dtype
    # Only an IFM with the Signed flag constrains the OFM. Any other IFM
    # passes without the OFM type being inspected (short-circuit).
    ifm_signed = bool(ifm_dtype.type & BaseType.Signed)
    valid = not ifm_signed or bool(ofm_dtype.type & BaseType.Signed)
    return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
703
@staticmethod
def constraint_unsigned_valid(op):
    "For IFM that are unsigned, OFM must either be the same type or int32"
    ifm_dtype = op.ifm.dtype
    ofm_dtype = op.ofm.dtype
    message = f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"
    if not (ifm_dtype.type & BaseType.Unsigned):
        # Only IFMs carrying the Unsigned flag are constrained by this check
        return True, message
    # Unsigned IFM: OFM may keep the IFM type or widen to int32
    return (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32), message
713
@staticmethod
def constraint_input_signed(op):
    "IFM must be int8 or int16"
    dtype = op.ifm.dtype
    # Accept only the signed 8-bit and 16-bit integer IFM types
    is_allowed = (dtype == DataType.int8) or (dtype == DataType.int16)
    return is_allowed, f"Op has ifm_dtype={dtype}"
720
@staticmethod
def constraint_input_8bit(op):
    "IFM must be int8 or uint8"
    dtype = op.ifm.dtype
    # Accept either signedness, but only at 8-bit width
    is_allowed = (dtype == DataType.int8) or (dtype == DataType.uint8)
    return is_allowed, f"Op has ifm_dtype={dtype}"
727
@staticmethod
def constraint_argmax_output(op):
    "OFM must be int32 or int64"
    dtype = op.ofm.dtype
    # The OFM type must be one of the two listed index types
    return dtype in (DataType.int32, DataType.int64), f"Op has ofm_dtype={dtype}"
734
@staticmethod
def constraint_matching_either_shapes(op):
    "At least one Input's shape must match the OFM's shape"
    ifm_shape = op.ifm.shape
    # The second input is optional; a missing IFM2 is reported as None
    ifm2_shape = None
    if op.ifm2:
        ifm2_shape = op.ifm2.shape
    ofm_shape = op.ofm.shape
    valid = any(shape == ofm_shape for shape in (ifm_shape, ifm2_shape))
    return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"
743
@staticmethod
def constraint_keep_dim_ifm_ofm(op):
    "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
    ifm_shape, ofm_shape = op.ifm.shape, op.ofm.shape
    # A missing or falsy keep_num_dims attribute leaves the ranks unconstrained
    valid = not op.attrs.get("keep_num_dims") or len(ifm_shape) == len(ofm_shape)
    return valid, f"Op has ifm shape={ifm_shape} and ofm shape={ofm_shape}"
751
@staticmethod
def constraint_mean_input_dims(op):
    "Input tensor must be at least 2D and at most 4D"
    # The docstring is the user-facing description of this constraint; it
    # previously omitted the upper bound although the check below enforces
    # it, so a 5D input was rejected with a misleading "at least 2D" reason.
    dims = len(op.inputs[0].shape)
    # Report the actual rank so either violated bound is visible
    return 2 <= dims <= 4, f"Input is {dims}D"
757
@staticmethod
def constraint_mean_axis(op):
    """Requirements for axis parameter:
    When IFM tensor is 2D:
      - Reduction in both axes is supported.
    When IFM tensor is 3D or 4D:
      - Reduction in Batch axis is only supported if batch size is 1.
      - Reduction in both Height and Width axes is supported.
      - Reduction in Depth axis is supported if at least one of H,W,C are of size 1."""
    input_shape = op.inputs[0].shape
    dims = len(input_shape)
    # The axis input is either a scalar (shape []) or a list of axes
    if op.inputs[1].shape == []:
        axis = [int(op.inputs[1].values)]
    else:
        axis = list(op.inputs[1].values)
    valid = True

    for ax in axis:
        # Negative axes are not normalised here; they are rejected as out of bounds.
        # (Previously this message lacked the f-prefix and printed literal braces.)
        if ax < 0 or ax >= dims:
            return False, f"Axis parameter is out of bounds. axis: {axis}, dims: {dims}. "

        # Batch is only supported if batch shape is 1
        if dims == 4 and ax == 0 and input_shape[0] != 1:
            valid = False
            break

        # Depth (last axis) is supported if any of h,w,c == 1
        if dims == 3 and ax == 2 and not any(s == 1 for s in input_shape):
            valid = False
            break
        if dims == 4 and ax == 3 and not any(s == 1 for s in input_shape[1:]):
            valid = False
            break

    return valid, f"Shape is {input_shape}, Axis is {axis}."
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200798
@staticmethod
def constraint_matching_in_out_quant(op):
    "Input and output quantisation must match."
    # The comparison itself is delegated to check_quantized_tens_scaling_equal
    if check_quantized_tens_scaling_equal(op.ifm, op.ofm):
        return True, "IFM and OFM quantisation parameters matches."
    return False, "IFM and OFM quantisation parameters are not equal."
805
@staticmethod
def constraint_matching_in_out_elements(op):
    "Input and output number of elements must match."
    # Only the total element count matters, not the individual dimensions
    ifm_elems = shape_num_elements(op.ifm.shape)
    ofm_elems = shape_num_elements(op.ofm.shape)
    if ifm_elems == ofm_elems:
        return True, "IFM and OFM number of elements are equal."
    return False, f"IFM {op.ifm.shape} and OFM {op.ofm.shape} number of elements are not equal."
812
@staticmethod
def constraint_lstm_dimensions(op):
    "IFM and OFM must have 3D shape"
    ifm_rank = len(op.ifm.shape)
    ofm_rank = len(op.ofm.shape)
    # Both tensors must independently be rank 3
    valid = ifm_rank == 3 and ofm_rank == 3
    return valid, f"Op has ifm shape {op.ifm.shape} and ofm shape {op.ofm.shape}"
818
@staticmethod
def constraint_lstm_inputs(op):
    "Must have 24 input tensors"
    # The input count must match exactly; both fewer and more are rejected
    expected = 24
    count = len(op.inputs)
    return count == expected, f"Op has {count} inputs"
824
@staticmethod
def constraint_lstm_intermediates(op):
    "Must have 5 intermediate tensors"
    # The intermediate tensor count must match exactly
    expected = 5
    count = len(op.intermediates)
    return count == expected, f"Op has {count} intermediates"
830
@staticmethod
def constraint_lstm_variables(op):
    "State tensors must be variable"
    # Inputs at indices 18 and 19 are the state tensors checked here;
    # the names of any that are not variable are listed in the message.
    offenders = [tens.name for tens in op.inputs[18:20] if not tens.is_variable]
    extra = ", ".join(offenders)
    return not offenders, f"Op has non-variable state tensor(s): {extra}"
842
@staticmethod
def constraint_transpose_permutation_size(op):
    "Permutation array must be a 1D tensor with RANK(IFM) elements"
    ifm_rank = len(op.inputs[0].shape)
    perm_shape = op.inputs[1].shape
    # The permutation must be a vector whose length equals the IFM rank
    is_vector = len(perm_shape) == 1
    valid = is_vector and perm_shape[0] == ifm_rank
    return valid, f"Op has ifm_dimension={ifm_rank} and permutation shape {perm_shape}"
850
@staticmethod
def constraint_transpose_permutation_values(op):
    "Permutation array must have constant values in the range [0, RANK(IFM))"
    ifm_rank = len(op.inputs[0].shape)
    perm_values = op.inputs[1].values
    if perm_values is None:
        # Non-constant permutation: values unknown at compile time, reject
        valid = False
    else:
        valid = not any(val < 0 or val >= ifm_rank for val in perm_values)
    return valid, f"Op has ifm_dimension={ifm_rank} and permutation values are: {perm_values}"
860
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200861
def tflite_semantic_checker(nng):
    """Run the TFLite semantic checks on every operation in *nng*.

    Each op's ``run_on_npu`` attribute is set to the result of
    ``TFLiteSemantic.is_operator_semantic_valid``. The graph is updated in
    place and also returned for convenience.
    """
    checker = TFLiteSemantic()
    for subgraph in nng.subgraphs:
        for operation in subgraph.get_all_ops():
            operation.run_on_npu = checker.is_operator_semantic_valid(operation)
    return nng