blob: 6677048780aaf30aa2a24c083039a34df195ac78 [file] [log] [blame]
Johan Alfven12e48112023-01-31 10:26:26 +01001# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02002#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
Rickard Bolinbc6ee582022-11-04 08:24:29 +000016#
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020017# Description:
18# The TFLiteSemantic class which is a collection of TensorFlow lite model semantic checks.
19from collections import defaultdict
20
21import numpy as np
22
23from .data_type import BaseType
24from .data_type import DataType
25from .numeric_util import is_integer
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020026from .operation import Op
27from .supported_operators_util import docstring_format_args
28from .supported_operators_util import list_formatter
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020029from .tensor import check_quantized_tens_scaling_equal
Johan Alfven3ac03be2023-03-01 09:53:35 +010030from .tensor import shape_num_elements
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020031from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
32from .tflite_mapping import optype_to_builtintype
33
34
def _optype_formatter(op_list):
    """Format internal op types as a list of their external (TFLite builtin) operator names.

    Op types whose builtin mapping is BUILTIN_OPERATOR_UNKNOWN are left out of the result.
    """
    builtin_names = (optype_to_builtintype(op_type) for op_type in op_list)
    known_names = (name for name in builtin_names if name is not BUILTIN_OPERATOR_UNKNOWN)
    return list_formatter(known_names)
41
42
class TFLiteSemantic:
    """Collection of TensorFlow Lite model semantic checks.

    Each check ("constraint") is a static/class method taking an operation and returning a tuple
    ``(valid, extra)``. When a check fails, the constraint's docstring is printed as the description of
    the problem, so constraint docstrings are user-facing output.
    """

    # Categorised lists of operators
    convolution_ops = {
        Op.Conv2DBias,
        Op.Conv2D,
        Op.QuantizedConv2D,
    }
    depthwise_convolution_ops = {Op.DepthwiseConv2DBias}
    transpose_convolution_ops = {Op.Conv2DBackpropInput}
    convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = {Op.ReduceSum} | max_pooling_ops | avg_pooling_ops
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = {
        Op.Minimum,
        Op.Maximum,
    }
    binary_elem_wise_shift_ops = {
        Op.SHL,
        Op.SHR,
    }
    binary_elem_wise_add_mul_sub = {
        Op.Add,
        Op.Mul,
        Op.Sub,
    }
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    # Op types that are permitted to have scalar (shape []) input tensors, see constraint_tens_input_scalar
    shapeless_input_ops = binary_elem_wise_main_ops | {
        Op.Split,
        Op.SplitV,
        Op.Mean,
        Op.ExpandDims,
        Op.Quantize,
        Op.ArgMax,
    }
    reshape_ops = {
        Op.Reshape,
        Op.QuantizedReshape,
        Op.Squeeze,
        Op.ExpandDims,
    }

    def __init__(self):
        # Setup the generic constraints, applied to every op type (minus the exclusions returned by
        # get_generic_constraint_exclude_list). Note: the order matters - checking stops at the first failing
        # constraint, and later constraints may rely on earlier ones having passed.
        self.generic_constraints = [
            TFLiteSemantic.constraint_attributes_specified,
            TFLiteSemantic.constraint_tens_no_dynamic,
            TFLiteSemantic.constraint_tens_defined_shape,
            TFLiteSemantic.constraint_tens_output_scalar,
            TFLiteSemantic.constraint_tens_input_scalar,
            TFLiteSemantic.constraint_tens_shape_size,
            TFLiteSemantic.constraint_tens_quant_none_check,
            TFLiteSemantic.constraint_tens_quant_scale,
            TFLiteSemantic.constraint_quant_scale_inf,
            TFLiteSemantic.constraint_none_const_tensors,
        ]

        # Setup specific constraints, keyed by op type. Note: the order matters (e.g. the StridedSlice input
        # count is checked before the inputs are unpacked by the following StridedSlice checks)
        self.specific_constraints = defaultdict(list)

        # Conv-like checks:
        for op_type in TFLiteSemantic.convolution_like_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
            if op_type not in TFLiteSemantic.transpose_convolution_ops:
                # Transpose Conv does not contain dilation
                self.specific_constraints[op_type].append(TFLiteSemantic.constraint_dilation_type)

        # Pooling checks:
        for op_type in TFLiteSemantic.pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
        # AVG pooling specific checks:
        for op_type in TFLiteSemantic.avg_pooling_ops:
            self.specific_constraints[op_type].extend(
                [TFLiteSemantic.constraint_matching_in_out_types, TFLiteSemantic.constraint_filter_type]
            )
        # MAX pooling specific checks:
        for op_type in TFLiteSemantic.max_pooling_ops:
            self.specific_constraints[op_type].extend(
                [TFLiteSemantic.constraint_matching_in_out_types, TFLiteSemantic.constraint_filter_type]
            )

        # Concat specific checks:
        for op_type in (Op.Concat, Op.ConcatTFLite):
            self.specific_constraints[op_type].extend(
                [
                    TFLiteSemantic.constraint_axis_exists,
                    TFLiteSemantic.constraint_axis_valid,
                    TFLiteSemantic.constraint_matching_dimensionality,
                    TFLiteSemantic.constraint_valid_dimensions,
                    TFLiteSemantic.constraint_valid_dimensions_axis,
                ]
            )

        # Element-wise checks:
        for op_type in TFLiteSemantic.elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_either_shapes)
        # Unary specific checks:
        for op_type in TFLiteSemantic.unary_elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Min/Max specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_min_max_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Add/Mul/Sub specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_add_mul_sub:
            self.specific_constraints[op_type].extend(
                [
                    TFLiteSemantic.constraint_matching_inputs_types,
                    TFLiteSemantic.constraint_matching_signed,
                    TFLiteSemantic.constraint_unsigned_valid,
                ]
            )

        # Ops reshaping dimensions: Reshape, Squeeze and ExpandDims
        for op_type in TFLiteSemantic.reshape_ops:
            self.specific_constraints[op_type].extend(
                [
                    TFLiteSemantic.constraint_matching_in_out_quant,
                    TFLiteSemantic.constraint_matching_in_out_elements,
                ]
            )

        # Softmax specific checks:
        self.specific_constraints[Op.Softmax].extend(
            [
                TFLiteSemantic.constraint_matching_shapes,
                TFLiteSemantic.constraint_matching_in_out_types,
                TFLiteSemantic.constraint_beta_value_range,
            ]
        )

        # Split specific checks (the axis must be validated before it is used to index the shape):
        self.specific_constraints[Op.Split].extend(
            [TFLiteSemantic.constraint_split_axis, TFLiteSemantic.constraint_split_num_splits]
        )

        # SplitV specific checks:
        self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred)

        # StridedSlice specific checks:
        self.specific_constraints[Op.StridedSlice].extend(
            [
                TFLiteSemantic.constraint_stridedslice_input_count,
                TFLiteSemantic.constraint_stridedslice_inputs_const,
                TFLiteSemantic.constraint_ellipsis_mask,
                TFLiteSemantic.constraint_axis_masks,
                TFLiteSemantic.constraint_slice_ranges,
            ]
        )

        # FullyConnected specific checks:
        self.specific_constraints[Op.FullyConnected].extend(
            [TFLiteSemantic.constraint_fc_output_2d, TFLiteSemantic.constraint_keep_dim_ifm_ofm]
        )

        # Pad specific checks:
        self.specific_constraints[Op.Pad].extend(
            [TFLiteSemantic.constraint_pad_input_count, TFLiteSemantic.constraint_pad_constant]
        )

        # HardSwish specific checks:
        self.specific_constraints[Op.HardSwish].extend(
            [TFLiteSemantic.constraint_input_8bit, TFLiteSemantic.constraint_matching_in_out_types]
        )

        # Mean specific checks (input rank is validated before the axis check relies on it):
        self.specific_constraints[Op.Mean].extend(
            [
                TFLiteSemantic.constraint_input_8bit,
                TFLiteSemantic.constraint_mean_input_dims,
                TFLiteSemantic.constraint_mean_axis,
            ]
        )

        # ArgMax specific checks:
        self.specific_constraints[Op.ArgMax].extend(
            [TFLiteSemantic.constraint_input_8bit, TFLiteSemantic.constraint_argmax_output]
        )

        # UnidirectionalSequenceLstm specific checks:
        self.specific_constraints[Op.UnidirectionalSequenceLstm].extend(
            [
                TFLiteSemantic.constraint_input_signed,
                TFLiteSemantic.constraint_matching_in_out_types,
                TFLiteSemantic.constraint_lstm_dimensions,
                TFLiteSemantic.constraint_lstm_inputs,
                TFLiteSemantic.constraint_lstm_intermediates,
                TFLiteSemantic.constraint_lstm_variables,
            ]
        )

    def is_operator_semantic_valid(self, op):
        """Run the applicable generic and op-specific constraints on op.

        Placeholder, SubgraphInput and Const ops are always considered valid. Otherwise the constraints are
        evaluated in order. On the first failure, a warning with the constraint's docstring (and any extra
        detail) is printed and False is returned. Returns True when every constraint passes.
        """
        ext_type = optype_to_builtintype(op.type)

        if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
            return True

        # Not all generic constraints apply to every op type. Look up this op type's exclusions once instead
        # of rebuilding the exclusion table for every generic constraint.
        excluded_constraints = self.get_generic_constraint_exclude_list().get(op.type, [])
        filtered_generic_constraints = [
            constraint for constraint in self.generic_constraints if constraint not in excluded_constraints
        ]

        for constraint in filtered_generic_constraints + self.specific_constraints[op.type]:
            valid, extra = constraint(op)
            if not valid:
                print(
                    f"Warning: Unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead"
                )
                print(f" - {constraint.__doc__}")
                if extra:
                    print(f" {extra}")
                return False

        return True

    @staticmethod
    def get_generic_constraint_exclude_list():
        """Return a dict mapping op type -> list of generic constraints that must NOT be applied to it."""
        # Not all generic constraints can be applied to each operator
        generic_constraints_exclude_list = {
            Op.Shape: [
                TFLiteSemantic.constraint_tens_quant_none_check,
            ],
            Op.Quantize: [
                TFLiteSemantic.constraint_tens_no_dynamic,
                TFLiteSemantic.constraint_tens_output_scalar,
            ],
            Op.ArgMax: [
                TFLiteSemantic.constraint_tens_quant_none_check,
            ],
        }
        return generic_constraints_exclude_list

    @staticmethod
    def constraint_none_const_tensors(op):
        "Constant tensors should not have NoneType-values"
        # Collect every offending input (not just the last one) so the warning names all of them
        names = [
            str(tens.name)
            for tens in filter(None, op.inputs)
            if len(tens.ops) > 0 and tens.ops[0].type == Op.Const and tens.values is None
        ]
        extra = ", ".join(names)
        return len(names) == 0, f"Unexpected None value for constant tensor: {extra}"

    @staticmethod
    def constraint_attributes_specified(op):
        "All required operator attributes must be specified"
        # operators that have been created internally (i.e. not created as part of reading an input network) may not
        # have the read error attribute
        attribute_read_error = op.attrs.get("attribute_read_error", [])
        valid = len(attribute_read_error) == 0
        extra = ", ".join(attribute_read_error)
        return valid, f"Op has missing attributes: {extra}"

    @staticmethod
    def constraint_tens_no_dynamic(op):
        "Input(s) and Output tensors must not be dynamic"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            # A scalar shape without constant values is treated as a dynamic tensor
            if (tens.shape == []) and (tens.values is None):
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has dynamic tensor(s): {extra}"

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output tensors must have a defined shape"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if not tens.has_fully_defined_shape():
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_scalar(op):
        "Output tensors cannot be scalar"
        ofm = op.ofm
        valid = ofm.shape != []
        return valid, f"Output Tensor '{ofm.name}' is scalar"

    @classmethod
    @docstring_format_args([_optype_formatter(shapeless_input_ops)])
    def constraint_tens_input_scalar(cls, op):
        "Scalar Input tensors are only valid for op type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has scalar input tensor(s): {extra}"

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output tensors must not be greater than 4D"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if len(tens.shape) > 4:
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Input(s), Output and Weight tensors must have quantization parameters"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is None:
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has tensors with missing quantization parameters: {extra}"

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Input(s), Output and Weight tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if (
                tens.quantization
                and tens.quantization.scale_f32 is not None
                and np.isinf(tens.quantization.scale_f32).any()
            ):
                valid = False
                extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_fc_output_2d(op):
        """The output tensor(s) must have 2D shape"""
        # Valid when the IFM can be viewed as 2D using the weights' second-to-last dimension
        valid = op.ifm.get_shape_as_2d(op.weights.shape[-2]) is not None
        extra = f"Op has non-2D output tensor '{op.ofm.name}'" if not valid else ""

        return valid, extra

    @staticmethod
    def constraint_stride_type(op):
        "Stride values for both width and height must be integer types"
        w, h = op.get_kernel_stride()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_dilation_type(op):
        "Dilation factor values for both width and height must be integer types"
        w, h = op.get_kernel_dilation()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_quant_scale_inf(op):
        "Input and Output tensors must have quantization scales that fit within float32 precision"
        float32_tiny = np.finfo(np.float32).tiny
        if op.ofm is not None and op.ofm.is_quantized():
            ofm_scale = op.ofm.quantization.scale_f32
            if np.any(ofm_scale < float32_tiny):
                return (
                    False,
                    f"The quantization scale of the output tensor is {ofm_scale}, "
                    + f"minimum supported is: {float32_tiny}",
                )
            # The IFM/OFM scale ratio can only be checked when the OFM is quantized as well: ofm_scale is
            # undefined otherwise (checking it unconditionally raised UnboundLocalError for a quantized IFM
            # with a non-quantized OFM).
            if op.ifm is not None and op.ifm.is_quantized():
                ifm_scale = op.ifm.quantization.scale_f32
                if np.any(np.isinf(ifm_scale / ofm_scale)):
                    return (
                        False,
                        f"IFM scale divided by OFM scale is infinite, ifm_scale={ifm_scale} ofm_scale={ofm_scale}",
                    )
        return True, "Op's quantization is ok"

    @staticmethod
    def constraint_matching_in_out_types(op):
        "IFM and OFM data types must match"
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        valid = ifm_dtype == ofm_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_beta_value_range(op):
        "Beta value needs to be positive"
        # A missing beta attribute defaults to 1.0
        beta = op.attrs.get("beta", 1.0)
        valid = beta >= 0
        return valid, f"Op has beta={beta}"

    @staticmethod
    def constraint_filter_type(op):
        "Kernel filter values for both width and height must be integer types"
        w = op.kernel.width
        h = op.kernel.height
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_matching_shapes(op):
        "IFM and OFM shapes must match"
        ifm_shape = op.ifm.shape
        ofm_shape = op.ofm.shape
        valid = ifm_shape == ofm_shape
        return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_split_axis(op):
        "Axis value must be in the range [-RANK(IFM) to +RANK(IFM))"
        # Split inputs are (axis, input)
        axis_tens = op.inputs[0]
        input_tens = op.inputs[1]
        dims = len(input_tens.shape)
        axis = int(axis_tens.values)
        # Convert a negative axis to its positive equivalent
        axis += dims if axis < 0 else 0
        valid = 0 <= axis < dims
        return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}"

    @staticmethod
    def constraint_split_num_splits(op):
        "Axis must be divisible by number of splits"
        num_splits = op.attrs.get("num_splits")
        axis_tens = op.inputs[0]
        input_tens = op.inputs[1]
        dims = len(input_tens.shape)
        axis = int(axis_tens.values)
        # Axis range has already been validated by constraint_split_axis
        axis += dims if axis < 0 else 0
        valid = input_tens.shape[axis] % num_splits == 0
        return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}"

    @staticmethod
    def constraint_splitv_inferred(op):
        "Only one size is allowed to be inferred"
        # A size of -1 means "infer this size from the remaining elements"
        sizes = op.inputs[1].values
        valid = np.count_nonzero(sizes == -1) <= 1
        return valid, f"Op has multiple inferred sizes (-1): {sizes}"

    @staticmethod
    def constraint_axis_exists(op):
        "Axis attribute must exist"
        axis = op.attrs.get("axis")
        valid = axis is not None
        return valid, f"Op has axis={axis}"

    @staticmethod
    def constraint_axis_valid(op):
        "Axis attribute must be in the range [0, <ofm_dimensions>)"
        dims = len(op.ofm.shape)
        axis = op.attrs["axis"]
        # Negative axes are accepted and converted to their positive equivalent before the range check
        axis += dims if axis < 0 else 0
        valid = 0 <= axis < dims
        return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"

    @staticmethod
    def constraint_matching_dimensionality(op):
        "All Input dimensionalities must match OFM dimensionality"
        valid = True
        extra = []
        ofm_dim = len(op.ofm.shape)
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            dim = len(tens.shape)
            if dim != ofm_dim:
                valid = False
                extra.append(f"Tensor '{tens.name}' has dimension: {dim}")
        extra = ", ".join(extra)
        return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_valid_dimensions(op):
        "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
        valid = True
        extra = []
        ofm_shape = op.ofm.shape
        ofm_dim = len(ofm_shape)
        axis = op.attrs["axis"]
        axis += ofm_dim if axis < 0 else 0
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        extra = ", ".join(extra)
        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_valid_dimensions_axis(op):
        """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute"""
        extra = []
        ofm_shape = op.ofm.shape
        ofm_dim = len(ofm_shape)
        axis = op.attrs["axis"]
        axis += ofm_dim if axis < 0 else 0

        sum_ifm_axis = 0
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            sum_ifm_axis += tens.shape[axis]
            # Every input is listed since the mismatch is a property of the sum, not of a single input
            extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")

        valid = sum_ifm_axis == ofm_shape[axis]
        extra = ", ".join(extra)
        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_stridedslice_input_count(op):
        "Exactly 4 Input tensors are required"
        inputs = len(op.inputs)
        valid = inputs == 4
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_input_count(op):
        "Number of input tensors must be exactly 2"
        inputs = len(op.inputs)
        valid = inputs == 2
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_constant(op):
        "The padding tensor must be constant"
        pad_tensor = op.inputs[1].values
        valid = pad_tensor is not None
        return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"

    @staticmethod
    def constraint_stridedslice_inputs_const(op):
        "Begin, End and Stride Input tensors must be constant"
        valid = True
        extra = []
        # The input count has already been validated by constraint_stridedslice_input_count
        _, begin, end, strides = op.inputs
        if begin.values is None:
            valid = False
            extra.append(f"Begin tensor '{begin.name}'")
        if end.values is None:
            valid = False
            extra.append(f"End tensor '{end.name}'")
        if strides.values is None:
            valid = False
            extra.append(f"Stride tensor '{strides.name}'")
        extra = ", ".join(extra)
        return valid, f"Op has non-constant tensors: {extra}"

    @staticmethod
    def constraint_ellipsis_mask(op):
        "ellipsis_mask must be 0"
        ellipsis = op.attrs["ellipsis_mask"]
        valid = ellipsis == 0
        return valid, f"Op has ellipsis mask as: {ellipsis}"

    @staticmethod
    def constraint_axis_masks(op):
        "new_axis_mask and shrink_axis_mask cannot both be set"
        new_axis = op.attrs["new_axis_mask"]
        shrink_axis = op.attrs["shrink_axis_mask"]
        valid = (new_axis == 0) or (shrink_axis == 0)
        return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"

    @staticmethod
    def _get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
        """For the strided slice operator: compute the start (is_begin=True) or end offsets per dimension.

        input_shape is a list of ints, offset_tens is a constant tensor of per-dimension offsets and
        offset_mask is a bitmask of dimensions whose offset should be ignored. Ignored dimensions default to
        0 for begin offsets and to the full dimension size for end offsets. Negative offsets are converted
        to positive indices.
        """
        offsets = len(input_shape) * [0] if is_begin else input_shape[:]
        for idx in range(len(input_shape)):
            # If the i:th bit in the mask is not set then the value in offset_tens[i] should be used, otherwise it
            # should be ignored
            if (offset_mask & (1 << idx)) == 0:
                offsets[idx] = offset_tens.values[idx]
                if offsets[idx] < 0:
                    # Convert negative indexing to positive ones
                    offsets[idx] += input_shape[idx]
        return offsets

    @staticmethod
    def constraint_slice_ranges(op):
        "Slice 'end' values must be greater than 'begin' values"
        ifm, begin, end, _ = op.inputs
        shrink_axis_mask = op.attrs["shrink_axis_mask"]
        # Calculate offset begin/end
        offset_begin = TFLiteSemantic._get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
        offset_end = TFLiteSemantic._get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
        # Check "end - begin" doesn't result in any zero or negative elements
        valid = True
        # if a shrink mask bit is set then the end position provided by the operation should be ignored, and instead a
        # new end position should be calculated so that calculations in the graph optimiser, such as (end - start),
        # result in the correct value. otherwise, we just need to check that the begin and end values are valid
        for i in range(len(ifm.shape)):
            if (shrink_axis_mask & (1 << i)) != 0:
                offset_end[i] = offset_begin[i] + 1
            elif offset_end[i] <= offset_begin[i]:
                valid = False

        # Side effect: the computed offsets are stored on the op for later use
        op.attrs["offset_begin"] = offset_begin
        op.attrs["offset_end"] = offset_end
        return valid, f"Op has begin_values={begin.values} and end_values={end.values}"

    @staticmethod
    def constraint_matching_inputs_types(op):
        "Both Input data types must match"
        ifm_dtype = op.ifm.dtype
        ifm2_dtype = op.ifm2.dtype
        valid = ifm_dtype == ifm2_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"

    @staticmethod
    def constraint_matching_signed(op):
        "For IFM that are signed, OFM must also be signed"
        valid = True
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        if ifm_dtype.type & BaseType.Signed:
            valid = bool(ofm_dtype.type & BaseType.Signed)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_unsigned_valid(op):
        "For IFM that are unsigned, OFM must either be the same type or int32"
        valid = True
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        if ifm_dtype.type & BaseType.Unsigned:
            valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_input_signed(op):
        "IFM must be int8 or int16"
        ifm_dtype = op.ifm.dtype
        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.int16)
        return valid, f"Op has ifm_dtype={ifm_dtype}"

    @staticmethod
    def constraint_input_8bit(op):
        "IFM must be int8 or uint8"
        ifm_dtype = op.ifm.dtype
        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
        return valid, f"Op has ifm_dtype={ifm_dtype}"

    @staticmethod
    def constraint_argmax_output(op):
        "OFM must be int32 or int64"
        ofm_dtype = op.ofm.dtype
        valid = ofm_dtype in (DataType.int32, DataType.int64)
        return valid, f"Op has ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_matching_either_shapes(op):
        "At least one Input's shape must match the OFM's shape"
        ifm_shape = op.ifm.shape
        # Unary ops have no second input
        ifm2_shape = op.ifm2.shape if op.ifm2 else None
        ofm_shape = op.ofm.shape
        valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
        return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_keep_dim_ifm_ofm(op):
        "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
        valid = True
        if op.attrs.get("keep_num_dims"):
            valid = len(op.ifm.shape) == len(op.ofm.shape)
        return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"

    @staticmethod
    def constraint_mean_input_dims(op):
        "Input tensor must be at least 2D"
        dims = len(op.inputs[0].shape)
        return 2 <= dims <= 4, f"Input is {dims}D"

    @staticmethod
    def constraint_mean_axis(op):
        "Axis indices must correspond to height and width axes"
        dims = len(op.inputs[0].shape)
        # The axis tensor is either a scalar or a list of axes
        axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
        # Ranks outside 2..4 are rejected by constraint_mean_input_dims (which runs first); default to invalid so
        # that an unexpected rank reports a failure instead of raising UnboundLocalError.
        valid = False
        if dims in (2, 3):
            valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
        elif dims == 4:
            valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
        return valid, f"Axis is {axis}"

    @staticmethod
    def constraint_matching_in_out_quant(op):
        "Input and output quantisation must match."
        if not check_quantized_tens_scaling_equal(op.ifm, op.ofm):
            return False, "IFM and OFM quantisation parameters are not equal."
        return True, "IFM and OFM quantisation parameters matches."

    @staticmethod
    def constraint_matching_in_out_elements(op):
        "Input and output number of elements must match."
        if shape_num_elements(op.ifm.shape) != shape_num_elements(op.ofm.shape):
            return False, f"IFM {op.ifm.shape} and OFM {op.ofm.shape} number of elements are not equal."
        return True, "IFM and OFM number of elements are equal."

    @staticmethod
    def constraint_lstm_dimensions(op):
        "IFM and OFM must have 3D shape"
        valid = len(op.ifm.shape) == len(op.ofm.shape) == 3
        return valid, f"Op has ifm shape {op.ifm.shape} and ofm shape {op.ofm.shape}"

    @staticmethod
    def constraint_lstm_inputs(op):
        "Must have 24 input tensors"
        n_inputs = len(op.inputs)
        return n_inputs == 24, f"Op has {n_inputs} inputs"

    @staticmethod
    def constraint_lstm_intermediates(op):
        "Must have 5 intermediate tensors"
        n_intermediates = len(op.intermediates)
        return n_intermediates == 5, f"Op has {n_intermediates} intermediates"

    @staticmethod
    def constraint_lstm_variables(op):
        "State tensors must be variable"
        valid = True
        extra = []
        # Inputs 18 and 19 are the state tensors checked here
        for tens in op.inputs[18:20]:
            if not tens.is_variable:
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has non-variable state tensor(s): {extra}"
747
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200748
def tflite_semantic_checker(nng):
    """Mark every operation of the network graph with the outcome of the TFLite semantic checks.

    For each op in each subgraph of nng, op.run_on_npu is set to the result of
    TFLiteSemantic.is_operator_semantic_valid(op). The same (mutated) graph object is returned.
    """
    checker = TFLiteSemantic()
    every_op = (op for subgraph in nng.subgraphs for op in subgraph.get_all_ops())
    for op in every_op:
        op.run_on_npu = checker.is_operator_semantic_valid(op)
    return nng