blob: 2851ab16761cbfd8159d6810b93b2212d4664ca9 [file] [log] [blame]
# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# The TFLiteSemantic class, which is a collection of TensorFlow Lite model semantic checks.
19from collections import defaultdict
20
21import numpy as np
22
23from .data_type import BaseType
24from .data_type import DataType
25from .numeric_util import is_integer
26from .operation import get_slice_offsets
27from .operation import Op
28from .supported_operators_util import docstring_format_args
29from .supported_operators_util import list_formatter
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020030from .tensor import check_quantized_tens_scaling_equal
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020031from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
32from .tflite_mapping import optype_to_builtintype
33
34
def _optype_formatter(op_list):
    """Format internal op types as their external builtin names.

    Each entry of ``op_list`` is converted with ``optype_to_builtintype``;
    entries that map to ``BUILTIN_OPERATOR_UNKNOWN`` are dropped before the
    remaining names are handed to ``list_formatter``.
    """
    return list_formatter(
        builtin_name
        for builtin_name in (optype_to_builtintype(op_type) for op_type in op_list)
        if builtin_name is not BUILTIN_OPERATOR_UNKNOWN
    )
41
42
class TFLiteSemantic:
    """Collection of TensorFlow Lite model semantic checks.

    Every ``constraint_*`` callable takes an operation and returns a
    ``(valid, extra)`` tuple: ``valid`` tells whether the operation passes the
    check and ``extra`` is a human readable detail string.

    NOTE: the docstrings of the ``constraint_*`` callables are printed at
    runtime by ``is_operator_semantic_valid`` (via ``constraint.__doc__``), so
    their text is part of the program output and must not be reworded
    casually. Explanatory notes therefore live in ``#`` comments.
    """

    # Categorised lists of operators
    convolution_ops = {
        Op.Conv2DBias,
        Op.Conv2D,
        Op.QuantizedConv2D,
    }
    depthwise_convolution_ops = {Op.DepthwiseConv2DBias}
    transpose_convolution_ops = {Op.Conv2DBackpropInput}
    convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = {Op.ReduceSum} | max_pooling_ops | avg_pooling_ops
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = {
        Op.Minimum,
        Op.Maximum,
    }
    binary_elem_wise_shift_ops = {
        Op.SHL,
        Op.SHR,
    }
    binary_elem_wise_add_mul_sub = {
        Op.Add,
        Op.Mul,
        Op.Sub,
    }
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    shapeless_input_ops = binary_elem_wise_main_ops | {Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize}
    reshape_ops = {
        Op.Reshape,
        Op.QuantizedReshape,
        Op.Squeeze,
        Op.ExpandDims,
    }

    def __init__(self):
        """Register the generic and the per-operator-type constraints.

        The constraints are evaluated in registration order and evaluation
        stops at the first failure, so later constraints may rely on earlier
        ones having passed.
        """
        # Setup the generic constraints. Note: the order matters
        self.generic_constraints = [
            TFLiteSemantic.constraint_tens_no_dynamic,
            TFLiteSemantic.constraint_tens_defined_shape,
            TFLiteSemantic.constraint_tens_output_scalar,
            TFLiteSemantic.constraint_tens_input_scalar,
            TFLiteSemantic.constraint_tens_shape_size,
            TFLiteSemantic.constraint_tens_quant_none_check,
            TFLiteSemantic.constraint_tens_quant_scale,
            TFLiteSemantic.constraint_quant_scale_inf,
            TFLiteSemantic.constraint_none_const_tensors,
        ]

        # Setup specific constraints. Note: the order matters
        self.specific_constraints = defaultdict(list)

        # Conv-like checks:
        for op_type in TFLiteSemantic.convolution_like_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
            if op_type not in TFLiteSemantic.transpose_convolution_ops:
                # Transpose Conv does not contain dilation
                self.specific_constraints[op_type].append(TFLiteSemantic.constraint_dilation_type)

        # Pooling checks:
        for op_type in TFLiteSemantic.pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
        # AVG pooling specific checks:
        for op_type in TFLiteSemantic.avg_pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_filter_type)
        # MAX pooling specific checks:
        for op_type in TFLiteSemantic.max_pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_filter_type)

        # Concat specific checks:
        for op_type in (Op.Concat, Op.ConcatTFLite):
            self.specific_constraints[op_type].extend(
                (
                    TFLiteSemantic.constraint_axis_exists,
                    TFLiteSemantic.constraint_axis_valid,
                    TFLiteSemantic.constraint_matching_dimensionality,
                    TFLiteSemantic.constraint_valid_dimensions,
                    TFLiteSemantic.constraint_valid_dimensions_axis,
                )
            )

        # Element-wise checks:
        for op_type in TFLiteSemantic.elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_either_shapes)
        # Unary specific checks:
        for op_type in TFLiteSemantic.unary_elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Min/Max specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_min_max_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Add/Mul/Sub specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_add_mul_sub:
            self.specific_constraints[op_type].extend(
                (
                    TFLiteSemantic.constraint_matching_inputs_types,
                    TFLiteSemantic.constraint_matching_signed,
                    TFLiteSemantic.constraint_unsigned_valid,
                )
            )

        # Ops reshaping dimensions: Reshape, Squeeze and ExpandDims
        for op_type in TFLiteSemantic.reshape_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_quant)

        # Softmax specific checks:
        self.specific_constraints[Op.Softmax].extend(
            (
                TFLiteSemantic.constraint_matching_shapes,
                TFLiteSemantic.constraint_matching_in_out_types,
                TFLiteSemantic.constraint_beta_value_range,
            )
        )

        # Split specific checks:
        self.specific_constraints[Op.Split].extend(
            (
                TFLiteSemantic.constraint_split_axis,
                TFLiteSemantic.constraint_split_num_splits,
            )
        )

        # SplitV specific checks:
        self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred)

        # StridedSlice specific checks:
        self.specific_constraints[Op.StridedSlice].extend(
            (
                TFLiteSemantic.constraint_stridedslice_input_count,
                TFLiteSemantic.constraint_stridedslice_inputs_const,
                TFLiteSemantic.constraint_ellipsis_mask,
                TFLiteSemantic.constraint_axis_masks,
                TFLiteSemantic.constraint_slice_ranges,
            )
        )

        # FullyConnected specific checks:
        self.specific_constraints[Op.FullyConnected].extend(
            (
                TFLiteSemantic.constraint_fc_output_2d,
                TFLiteSemantic.constraint_keep_dim_ifm_ofm,
            )
        )

        # Pad specific checks:
        self.specific_constraints[Op.Pad].extend(
            (
                TFLiteSemantic.constraint_pad_input_count,
                TFLiteSemantic.constraint_pad_constant,
            )
        )

        # HardSwish specific checks:
        self.specific_constraints[Op.HardSwish].extend(
            (
                TFLiteSemantic.constraint_input_8bit,
                TFLiteSemantic.constraint_matching_in_out_types,
            )
        )

        # Mean specific checks:
        self.specific_constraints[Op.Mean].extend(
            (
                TFLiteSemantic.constraint_input_8bit,
                TFLiteSemantic.constraint_mean_input_dims,
                TFLiteSemantic.constraint_mean_axis,
            )
        )

    def is_operator_semantic_valid(self, op):
        """Run all applicable constraints on ``op``.

        Placeholder, SubgraphInput and Const operations are always accepted.
        For any other operation the generic constraints (minus those excluded
        for its type) are evaluated first, followed by the constraints
        registered for its type. On the first failing constraint a warning is
        printed and ``False`` is returned.

        Returns:
            True when every evaluated constraint passes, False otherwise.
        """
        ext_type = optype_to_builtintype(op.type)

        if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
            return True

        # Generic constraints filtered to exclude certain constraints depending on op.type.
        # The exclude list is looked up once instead of being rebuilt for every constraint.
        excluded = self.get_generic_constraint_exclude_list().get(op.type, [])
        filtered_generic_constraints = [
            constraint for constraint in self.generic_constraints if constraint not in excluded
        ]

        for constraint in filtered_generic_constraints + self.specific_constraints[op.type]:
            valid, extra = constraint(op)
            if not valid:
                print(
                    f"Warning: Unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead"
                )
                print(f" - {constraint.__doc__}")
                if extra:
                    print(f" {extra}")
                return False

        return True

    @staticmethod
    def get_generic_constraint_exclude_list():
        """Return a dict mapping op types to the generic constraints they skip."""
        # Not all generic constraints can be applied to each operator
        generic_constraints_exclude_list = {
            Op.Shape: [
                TFLiteSemantic.constraint_tens_quant_none_check,
            ],
            Op.Quantize: [
                TFLiteSemantic.constraint_tens_no_dynamic,
                TFLiteSemantic.constraint_tens_output_scalar,
            ],
        }
        return generic_constraints_exclude_list

    @staticmethod
    def constraint_none_const_tensors(op):
        "Constant tensors should not have NoneType-values"
        # Report every offending constant input, not only the last one found
        names = [
            str(tens.name)
            for tens in filter(None, op.inputs)
            if len(tens.ops) > 0 and tens.ops[0].type == Op.Const and tens.values is None
        ]
        extra = ", ".join(names)
        return not names, f"Unexpected None value for constant tensor: {extra}"

    @staticmethod
    def constraint_tens_no_dynamic(op):
        "Input(s) and Output tensors must not be dynamic"
        # A tensor with an empty shape and no values is treated as dynamic
        names = [tens.name for tens in op.inputs + op.outputs if tens and tens.shape == [] and tens.values is None]
        extra = ", ".join(names)
        return not names, f"Op has dynamic tensor(s): {extra}"

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output tensors must have a defined shape"
        extra = [
            f"Tensor '{tens.name}' has shape: {tens.shape}"
            for tens in op.inputs + op.outputs
            if tens and not tens.has_fully_defined_shape()
        ]
        return not extra, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_scalar(op):
        "Output tensors cannot be scalar"
        ofm = op.ofm
        valid = ofm.shape != []
        return valid, f"Output Tensor '{ofm.name}' is scalar"

    @classmethod
    @docstring_format_args([_optype_formatter(shapeless_input_ops)])
    def constraint_tens_input_scalar(cls, op):
        "Scalar Input tensors are only valid for op type: {}"
        # Whether the op type tolerates scalar inputs does not depend on the tensor
        scalar_allowed = op.type in cls.shapeless_input_ops
        names = [tens.name for tens in op.inputs if tens and tens.shape == [] and not scalar_allowed]
        extra = ", ".join(names)
        return not names, f"Op has scalar input tensor(s): {extra}"

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output tensors must not be greater than 4D"
        extra = [
            f"Tensor '{tens.name}' has shape: {tens.shape}"
            for tens in op.inputs + op.outputs
            if tens and len(tens.shape) > 4
        ]
        return not extra, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Input(s), Output and Weight tensors must have quantization parameters"
        names = [tens.name for tens in op.get_ifm_ifm2_weights_ofm() if tens and tens.quantization is None]
        extra = ", ".join(names)
        return not names, f"Op has tensors with missing quantization parameters: {extra}"

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Input(s), Output and Weight tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if (
                tens.quantization
                and tens.quantization.scale_f32 is not None
                and np.isinf(tens.quantization.scale_f32).any()
            ):
                valid = False
                extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_fc_output_2d(op):
        """The output tensor(s) must have 2D shape"""
        valid = op.ifm.get_shape_as_2d(op.weights.shape[-2]) is not None
        extra = f"Op has non-2D output tensor '{op.ofm.name}'" if not valid else ""

        return valid, extra

    @staticmethod
    def constraint_stride_type(op):
        "Stride values for both width and height must be integer types"
        w, h = op.get_kernel_stride()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_dilation_type(op):
        "Dilation factor values for both width and height must be integer types"
        w, h = op.get_kernel_dilation()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_quant_scale_inf(op):
        "Input and Output tensors must have quantization scales that fit within float32 precision"
        if op.ofm is not None and op.ofm.is_quantized():
            ofm_scale = op.ofm.quantization.scale_f32
            min_scale = np.finfo(np.float32).tiny
            if np.any(ofm_scale < min_scale):
                return (
                    False,
                    f"The quantization scale of the output tensor is {ofm_scale}, "
                    + f"minimum supported is: {min_scale}",
                )
            # The IFM check needs ofm_scale, so it only runs when the OFM is quantized
            if op.ifm is not None and op.ifm.is_quantized():
                ifm_scale = op.ifm.quantization.scale_f32
                if np.any(np.isinf(ifm_scale / ofm_scale)):
                    return (
                        False,
                        f"IFM scale divided by OFM scale is infinite, ifm_scale={ifm_scale} ofm_scale={ofm_scale}",
                    )
        return True, "Op's quantization is ok"

    @staticmethod
    def constraint_matching_in_out_types(op):
        "IFM and OFM data types must match"
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        valid = ifm_dtype == ofm_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_beta_value_range(op):
        "Beta value needs to be positive"
        # A missing beta attribute is treated as 1.0; note that zero is accepted
        beta = op.attrs.get("beta", 1.0)
        valid = beta >= 0
        return valid, f"Op has beta={beta}"

    @staticmethod
    def constraint_filter_type(op):
        "Kernel filter values for both width and height must be integer types"
        w = op.kernel.width
        h = op.kernel.height
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_matching_shapes(op):
        "IFM and OFM shapes must match"
        ifm_shape = op.ifm.shape
        ofm_shape = op.ofm.shape
        valid = ifm_shape == ofm_shape
        return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_split_axis(op):
        "Axis value must be in the range [-RANK(IFM) to +RANK(IFM))"
        # For Split the axis tensor is input 0 and the data tensor is input 1
        axis_tens = op.inputs[0]
        input_tens = op.inputs[1]
        dims = len(input_tens.shape)
        axis = int(axis_tens.values)
        # Normalise a negative axis so that it counts from the front
        axis += dims if axis < 0 else 0
        valid = 0 <= axis < dims
        return valid, f"Op has ifm_dimensions={dims} and axis value is: {axis}"

    @staticmethod
    def constraint_split_num_splits(op):
        "Axis must be divisible by number of splits"
        # Registered after constraint_split_axis, which rejects out-of-range axes first
        num_splits = op.attrs.get("num_splits")
        axis_tens = op.inputs[0]
        input_tens = op.inputs[1]
        dims = len(input_tens.shape)
        axis = int(axis_tens.values)
        axis += dims if axis < 0 else 0
        valid = input_tens.shape[axis] % num_splits == 0
        return valid, f"Op has ifm shape={input_tens.shape} axis={axis} num_splits={num_splits}"

    @staticmethod
    def constraint_splitv_inferred(op):
        "Only one size is allowed to be inferred"
        # A size of -1 marks an inferred entry
        sizes = op.inputs[1].values
        valid = np.count_nonzero(sizes == -1) <= 1
        return valid, f"Op has multiple inferred sizes (-1): {sizes}"

    @staticmethod
    def constraint_axis_exists(op):
        "Axis attribute must exist"
        axis = op.attrs.get("axis")
        valid = axis is not None
        return valid, f"Op has axis={axis}"

    @staticmethod
    def constraint_axis_valid(op):
        "Axis attribute must be in the range [0, <ofm_dimensions>)"
        dims = len(op.ofm.shape)
        axis = op.attrs["axis"]
        # Normalise a negative axis so that it counts from the front
        axis += dims if axis < 0 else 0
        valid = 0 <= axis < dims
        return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"

    @staticmethod
    def constraint_matching_dimensionality(op):
        "All Input dimensionalities must match OFM dimensionality"
        ofm_dim = len(op.ofm.shape)
        mismatches = [
            f"Tensor '{tens.name}' has dimension: {len(tens.shape)}"
            for tens in op.inputs
            if tens and len(tens.shape) != ofm_dim
        ]
        extra = ", ".join(mismatches)
        return not mismatches, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_valid_dimensions(op):
        "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
        ofm_shape = op.ofm.shape
        ofm_dim = len(ofm_shape)
        axis = op.attrs["axis"]
        axis += ofm_dim if axis < 0 else 0
        mismatches = [
            f"Tensor '{tens.name}' has shape: {tens.shape}"
            for tens in op.inputs
            if tens and any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis)
        ]
        extra = ", ".join(mismatches)
        return (
            not mismatches,
            f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}",
        )

    @staticmethod
    def constraint_valid_dimensions_axis(op):
        """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute"""
        ofm_shape = op.ofm.shape
        ofm_dim = len(ofm_shape)
        axis = op.attrs["axis"]
        axis += ofm_dim if axis < 0 else 0

        tensors = [tens for tens in op.inputs if tens]
        sum_ifm_axis = sum(tens.shape[axis] for tens in tensors)
        # Every input is listed because the check is on their combined size
        extra = ", ".join(f"Tensor '{tens.name}' has shape: {tens.shape}" for tens in tensors)

        valid = sum_ifm_axis == ofm_shape[axis]
        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_stridedslice_input_count(op):
        "Exactly 4 Input tensors are required"
        inputs = len(op.inputs)
        valid = inputs == 4
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_input_count(op):
        "Number of input tensors must be exactly 2"
        inputs = len(op.inputs)
        valid = inputs == 2
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_constant(op):
        "The padding tensor must be constant"
        # A tensor without values is treated as non-constant
        pad_tensor = op.inputs[1].values
        valid = pad_tensor is not None
        return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"

    @staticmethod
    def constraint_stridedslice_inputs_const(op):
        "Begin, End and Stride Input tensors must be constant"
        # Registered after constraint_stridedslice_input_count, so exactly 4 inputs are expected
        _, begin, end, strides = op.inputs
        non_const = [
            f"{label} tensor '{tens.name}'"
            for label, tens in (("Begin", begin), ("End", end), ("Stride", strides))
            if tens.values is None
        ]
        extra = ", ".join(non_const)
        return not non_const, f"Op has non-constant tensors: {extra}"

    @staticmethod
    def constraint_ellipsis_mask(op):
        "ellipsis_mask must be 0"
        ellipsis = op.attrs["ellipsis_mask"]
        valid = ellipsis == 0
        return valid, f"Op has ellipsis mask as: {ellipsis}"

    @staticmethod
    def constraint_axis_masks(op):
        "new_axis_mask and shrink_axis_mask cannot both be set"
        new_axis = op.attrs["new_axis_mask"]
        shrink_axis = op.attrs["shrink_axis_mask"]
        valid = (new_axis == 0) or (shrink_axis == 0)
        return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"

    @staticmethod
    def constraint_slice_ranges(op):
        "Slice 'end' values must be greater than 'begin' values"
        ifm, begin, end, _ = op.inputs
        # Calculate offset begin/end
        offset_begin = get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
        offset_end = get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
        # Check "end - begin" doesn't result in any zero or negative elements
        valid = all((e - b) > 0 for b, e in zip(offset_begin, offset_end))
        return valid, f"Op has begin_values={begin.values} and end_values={end.values}"

    @staticmethod
    def constraint_matching_inputs_types(op):
        "Both Input data types must match"
        ifm_dtype = op.ifm.dtype
        ifm2_dtype = op.ifm2.dtype
        valid = ifm_dtype == ifm2_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"

    @staticmethod
    def constraint_matching_signed(op):
        "For IFM that are signed, OFM must also be signed"
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        # Only signed IFMs are restricted; anything else passes
        valid = bool(ofm_dtype.type & BaseType.Signed) if ifm_dtype.type & BaseType.Signed else True
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_unsigned_valid(op):
        "For IFM that are unsigned, OFM must either be the same type or int32"
        valid = True
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        if ifm_dtype.type & BaseType.Unsigned:
            valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_input_8bit(op):
        "IFM must be int8 or uint8"
        ifm_dtype = op.ifm.dtype
        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
        return valid, f"Op has ifm_dtype={ifm_dtype}"

    @staticmethod
    def constraint_matching_either_shapes(op):
        "At least one Input's shape must match the OFM's shape"
        ifm_shape = op.ifm.shape
        # Unary ops have no second input
        ifm2_shape = op.ifm2.shape if op.ifm2 else None
        ofm_shape = op.ofm.shape
        valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
        return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_keep_dim_ifm_ofm(op):
        "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
        valid = True
        if op.attrs.get("keep_num_dims"):
            valid = len(op.ifm.shape) == len(op.ofm.shape)
        return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"

    @staticmethod
    def constraint_mean_input_dims(op):
        "Input tensor must be at least 2D"
        dims = len(op.inputs[0].shape)
        return 2 <= dims <= 4, f"Input is {dims}D"

    @staticmethod
    def constraint_mean_axis(op):
        "Axis indices must correspond to height and width axes"
        dims = len(op.inputs[0].shape)
        axis_tens = op.inputs[1]
        # A scalar axis tensor gives a single int, otherwise a list of axes
        axis = int(axis_tens.values) if axis_tens.shape == [] else list(axis_tens.values)
        if dims in (2, 3):
            valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
        elif dims == 4:
            valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
        else:
            # Any other rank is invalid rather than leaving the result undefined
            valid = False
        return valid, f"Axis is {axis}"

    @staticmethod
    def constraint_matching_in_out_quant(op):
        "Input and output quantisation must match."
        if not check_quantized_tens_scaling_equal(op.ifm, op.ofm):
            return False, "IFM and OFM quantisation parameters are not equal."
        return True, "IFM and OFM quantisation parameters matches."
640
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200641
def tflite_semantic_checker(nng):
    """Flag every operation of ``nng`` with the result of the semantic checks.

    Each operation returned by ``get_all_ops()`` of every subgraph gets its
    ``run_on_npu`` attribute set to the result of
    ``TFLiteSemantic.is_operator_semantic_valid``. The same ``nng`` object is
    returned.
    """
    checker = TFLiteSemantic()
    # Lazily walk the subgraphs so ops are fetched and flagged one subgraph at a time
    every_op = (operation for subgraph in nng.subgraphs for operation in subgraph.get_all_ops())
    for operation in every_op:
        operation.run_on_npu = checker.is_operator_semantic_valid(operation)
    return nng