blob: 7a0e234d104843b038a360baaf513c75124eeece [file] [log] [blame]
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
# Description:
# The TFLiteSemantic class, a collection of TensorFlow Lite model semantic checks.
18from collections import defaultdict
19
20import numpy as np
21
22from .data_type import BaseType
23from .data_type import DataType
24from .numeric_util import is_integer
25from .operation import get_slice_offsets
26from .operation import Op
27from .supported_operators_util import docstring_format_args
28from .supported_operators_util import list_formatter
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020029from .tensor import check_quantized_tens_scaling_equal
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020030from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
31from .tflite_mapping import optype_to_builtintype
32
33
def _optype_formatter(op_list):
    """Format internal op types as a list of external TFLite builtin names.

    Each op type is converted with ``optype_to_builtintype``; op types that map to
    ``BUILTIN_OPERATOR_UNKNOWN`` are dropped. The remaining names are passed to
    ``list_formatter`` and its result is returned.
    """
    # Convert internal op types to external names
    external_names = map(optype_to_builtintype, op_list)
    # Remove UNKNOWNs. Compare by value rather than identity so the filter does
    # not depend on the sentinel being the very same object.
    known_names = (name for name in external_names if name != BUILTIN_OPERATOR_UNKNOWN)
    return list_formatter(known_names)
40
41
class TFLiteSemantic:
    """Collection of TensorFlow Lite model semantic checks.

    Every check ("constraint") takes an operation and returns a ``(valid, extra)``
    tuple. When a check fails, ``is_operator_semantic_valid`` prints the
    constraint's docstring as the description of the violated rule, followed by
    ``extra``. The constraint docstrings are therefore user-facing text.
    """

    # Categorised lists of operators
    convolution_ops = {Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D}
    depthwise_convolution_ops = {Op.DepthwiseConv2DBias}
    transpose_convolution_ops = {Op.Conv2DBackpropInput}
    convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = {Op.ReduceSum} | max_pooling_ops | avg_pooling_ops
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = {Op.Minimum, Op.Maximum}
    binary_elem_wise_shift_ops = {Op.SHL, Op.SHR}
    binary_elem_wise_add_mul_sub = {Op.Add, Op.Mul, Op.Sub}
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    # Op types allowed to have scalar (shape == []) inputs; see constraint_tens_input_scalar
    shapeless_input_ops = binary_elem_wise_main_ops | {Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims, Op.Quantize}
    reshape_ops = {Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims}

    def __init__(self):
        # Setup the generic constraints. Note: the order matters. Checking stops at
        # the first failing constraint, and later checks may rely on earlier ones
        # having passed.
        self.generic_constraints = [
            TFLiteSemantic.constraint_tens_no_dynamic,
            TFLiteSemantic.constraint_tens_defined_shape,
            TFLiteSemantic.constraint_tens_output_scalar,
            TFLiteSemantic.constraint_tens_input_scalar,
            TFLiteSemantic.constraint_tens_shape_size,
            TFLiteSemantic.constraint_tens_quant_none_check,
            TFLiteSemantic.constraint_tens_quant_scale,
            TFLiteSemantic.constraint_quant_scale_inf,
            TFLiteSemantic.constraint_none_const_tensors,
        ]

        # Setup specific constraints. Note: the order matters (per op type)
        self.specific_constraints = defaultdict(list)

        # Conv-like checks:
        for op_type in TFLiteSemantic.convolution_like_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_dilation_type)

        # Pooling checks:
        for op_type in TFLiteSemantic.pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
        # AVG and MAX pooling specific checks (identical for both kinds):
        for op_type in TFLiteSemantic.avg_pooling_ops | TFLiteSemantic.max_pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_filter_type)

        # Concat specific checks:
        for op_type in (Op.Concat, Op.ConcatTFLite):
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_exists)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_valid)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_dimensionality)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions_axis)

        # Element-wise checks:
        for op_type in TFLiteSemantic.elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_either_shapes)
        # Unary specific checks:
        for op_type in TFLiteSemantic.unary_elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Min/Max specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_min_max_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Add/Mul/Sub specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_add_mul_sub:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_inputs_types)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_signed)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_unsigned_valid)

        # Ops reshaping dimensions: Reshape, Squeeze and ExpandDims
        for op_type in TFLiteSemantic.reshape_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_quant)

        # Softmax specific checks:
        self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_shapes)
        self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types)
        self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_beta_value_range)

        # SplitV specific checks:
        self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred)

        # StridedSlice specific checks:
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_stridedslice_input_count)
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_stridedslice_inputs_const)
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_ellipsis_mask)
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_axis_masks)
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_slice_ranges)

        # FullyConnected specific checks:
        self.specific_constraints[Op.FullyConnected].append(TFLiteSemantic.constraint_fc_output_2d)
        self.specific_constraints[Op.FullyConnected].append(TFLiteSemantic.constraint_keep_dim_ifm_ofm)

        # Pad specific checks:
        self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_input_count)
        self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_constant)

        # HardSwish specific checks:
        self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_input_8bit)
        self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_matching_in_out_types)

        # Mean specific checks:
        self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_input_8bit)
        self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_input_dims)
        self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_axis)

    def is_operator_semantic_valid(self, op):
        """Return True if ``op`` passes every applicable semantic check.

        Placeholder, SubgraphInput and Const ops are always accepted. Otherwise the
        generic constraints (minus any excluded for ``op.type``) run first, then the
        op-specific ones. On the first failure a warning naming the op and the
        failed rule is printed and False is returned.
        """
        ext_type = optype_to_builtintype(op.type)

        if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
            return True

        # Generic constraints list filtered out to exclude certain constraints depending on op.type.
        # The exclude table is built once here rather than once per constraint.
        excluded = self.get_generic_constraint_exclude_list().get(op.type, [])
        filtered_generic_constraints = [c for c in self.generic_constraints if c not in excluded]

        # .get avoids inserting an empty entry into the defaultdict for unknown op types
        for constraint in filtered_generic_constraints + self.specific_constraints.get(op.type, []):
            valid, extra = constraint(op)
            if not valid:
                print(
                    f"Warning: Unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead"
                )
                print(f" - {constraint.__doc__}")
                if extra:
                    print(f"   {extra}")
                return False

        return True

    @staticmethod
    def get_generic_constraint_exclude_list():
        # Not all generic constraints can be applied to each operator.
        # Maps op type -> list of generic constraints that are skipped for it.
        generic_constraints_exclude_list = {
            Op.Shape: [
                TFLiteSemantic.constraint_tens_quant_none_check,
            ],
            Op.Quantize: [
                TFLiteSemantic.constraint_tens_no_dynamic,
                TFLiteSemantic.constraint_tens_output_scalar,
            ],
        }
        return generic_constraints_exclude_list

    @staticmethod
    def constraint_none_const_tensors(op):
        "Constant tensors should not have NoneType-values"
        # Inputs produced by a Const op must carry values; report every such input
        # whose values are None (not just the last one found).
        invalid = [
            str(tens.name)
            for tens in op.inputs
            if tens and len(tens.ops) > 0 and tens.ops[0].type == Op.Const and tens.values is None
        ]
        extra = ", ".join(invalid)
        return not invalid, f"Unexpected None value for constant tensor: {extra}"

    @staticmethod
    def constraint_tens_no_dynamic(op):
        "Input(s) and Output tensors must not be dynamic"
        # A tensor with an empty shape and no values is treated as dynamic here
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (tens.values is None):
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has dynamic tensor(s): {extra}"

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output tensors must have a defined shape"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if not tens.has_fully_defined_shape():
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_scalar(op):
        "Output tensors cannot be scalar"
        ofm = op.ofm
        valid = ofm.shape != []
        return valid, f"Output Tensor '{ofm.name}' is scalar"

    @classmethod
    @docstring_format_args([_optype_formatter(shapeless_input_ops)])
    def constraint_tens_input_scalar(cls, op):
        "Scalar Input tensors are only valid for op type: {}"
        # The docstring placeholder is filled in with the shapeless_input_ops names
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has scalar input tensor(s): {extra}"

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output tensors must not be greater than 4D"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if len(tens.shape) > 4:
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Input(s), Output and Weight tensors must have quantization parameters"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is None:
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has tensors with missing quantization parameters: {extra}"

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Input(s), Output and Weight tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if (
                tens.quantization
                and tens.quantization.scale_f32 is not None
                and np.isinf(tens.quantization.scale_f32).any()
            ):
                valid = False
                extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_fc_output_2d(op):
        """The output tensor(s) must have 2D shape"""
        # Valid when the IFM can be viewed as 2D given the weights' second-to-last dimension
        valid = op.ifm.get_shape_as_2d(op.weights.shape[-2]) is not None
        extra = f"Op has non-2D output tensor '{op.ofm.name}'" if not valid else ""

        return valid, extra

    @staticmethod
    def constraint_stride_type(op):
        "Stride values for both width and height must be integer types"
        w, h = op.get_kernel_stride()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_dilation_type(op):
        "Dilation factor values for both width and height must be integer types"
        w, h = op.get_kernel_dilation()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_quant_scale_inf(op):
        "Input and Output tensors must have quantization scales that fit within float32 precision"
        if op.ofm is not None and op.ofm.is_quantized():
            ofm_scale = op.ofm.quantization.scale_f32
            if np.any(ofm_scale < np.finfo(np.float32).tiny):
                return (
                    False,
                    f"The quantization scale of the output tensor is {ofm_scale}, "
                    + f"minimum supported is: {np.finfo(np.float32).tiny}",
                )
            # The IFM/OFM scale ratio is only meaningful (and ofm_scale only bound)
            # when the OFM is quantized; nesting avoids an UnboundLocalError for a
            # quantized IFM feeding a non-quantized OFM.
            if op.ifm is not None and op.ifm.is_quantized():
                ifm_scale = op.ifm.quantization.scale_f32
                if np.any(np.isinf(ifm_scale / ofm_scale)):
                    return (
                        False,
                        f"IFM scale divided by OFM scale is infinite, ifm_scale={ifm_scale} ofm_scale={ofm_scale}",
                    )
        return True, "Op's quantization is ok"

    @staticmethod
    def constraint_matching_in_out_types(op):
        "IFM and OFM data types must match"
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        valid = ifm_dtype == ofm_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_beta_value_range(op):
        "Beta value needs to be positive"
        # A missing beta attribute defaults to 1.0; zero is accepted
        beta = op.attrs.get("beta", 1.0)
        valid = beta >= 0
        return valid, f"Op has beta={beta}"

    @staticmethod
    def constraint_filter_type(op):
        "Kernel filter values for both width and height must be integer types"
        w = op.kernel.width
        h = op.kernel.height
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_matching_shapes(op):
        "IFM and OFM shapes must match"
        ifm_shape = op.ifm.shape
        ofm_shape = op.ofm.shape
        valid = ifm_shape == ofm_shape
        return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_splitv_inferred(op):
        "Only one size is allowed to be inferred"
        # A size of -1 in the second input means "infer this split size"
        sizes = op.inputs[1].values
        valid = np.count_nonzero(sizes == -1) <= 1
        return valid, f"Op has multiple inferred sizes (-1): {sizes}"

    @staticmethod
    def constraint_axis_exists(op):
        "Axis attribute must exist"
        axis = op.attrs.get("axis")
        valid = axis is not None
        return valid, f"Op has axis={axis}"

    @staticmethod
    def constraint_axis_valid(op):
        "Axis attribute must be in the range [0, <ofm_dimensions>)"
        dims = len(op.ofm.shape)
        axis = op.attrs["axis"]
        # Negative axes count from the end. Normalise for the range check, but
        # report the axis as given so the message matches the model.
        norm_axis = axis + dims if axis < 0 else axis
        valid = 0 <= norm_axis < dims
        return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"

    @staticmethod
    def constraint_matching_dimensionality(op):
        "All Input dimensionalities must match OFM dimensionality"
        valid = True
        extra = []
        ofm_dim = len(op.ofm.shape)
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            dim = len(tens.shape)
            if dim != ofm_dim:
                valid = False
                extra.append(f"Tensor '{tens.name}' has dimension: {dim}")
        extra = ", ".join(extra)
        return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_valid_dimensions(op):
        "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
        # Indexes tens.shape[dim] for every OFM dim, so relies on
        # constraint_matching_dimensionality having passed first.
        valid = True
        extra = []
        ofm_shape = op.ofm.shape
        ofm_dim = len(ofm_shape)
        axis = op.attrs["axis"]
        axis += ofm_dim if axis < 0 else 0
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        extra = ", ".join(extra)
        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_valid_dimensions_axis(op):
        """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute"""
        ofm_shape = op.ofm.shape
        ofm_dim = len(ofm_shape)
        axis = op.attrs["axis"]
        axis += ofm_dim if axis < 0 else 0

        tensors = [tens for tens in op.inputs if tens]
        sum_ifm_axis = sum(tens.shape[axis] for tens in tensors)
        # Every input contributes to the sum, so all of them are listed on failure
        extra = ", ".join(f"Tensor '{tens.name}' has shape: {tens.shape}" for tens in tensors)

        valid = sum_ifm_axis == ofm_shape[axis]
        return (
            valid,
            f"Op has axis={axis}, ofm_shape={ofm_shape}, sum of input sizes along axis={sum_ifm_axis} "
            + f"and the list of inputs are: {extra}",
        )

    @staticmethod
    def constraint_stridedslice_input_count(op):
        "Exactly 4 Input tensors are required"
        inputs = len(op.inputs)
        valid = inputs == 4
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_input_count(op):
        "Number of input tensors must be exactly 2"
        inputs = len(op.inputs)
        valid = inputs == 2
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_constant(op):
        "The padding tensor must be constant"
        pad_tensor = op.inputs[1].values
        valid = pad_tensor is not None
        return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"

    @staticmethod
    def constraint_stridedslice_inputs_const(op):
        "Begin, End and Stride Input tensors must be constant"
        # Unpacking assumes exactly 4 inputs (checked by constraint_stridedslice_input_count)
        valid = True
        extra = []
        _, begin, end, strides = op.inputs
        if begin.values is None:
            valid = False
            extra.append(f"Begin tensor '{begin.name}'")
        if end.values is None:
            valid = False
            extra.append(f"End tensor '{end.name}'")
        if strides.values is None:
            valid = False
            extra.append(f"Stride tensor '{strides.name}'")
        extra = ", ".join(extra)
        return valid, f"Op has non-constant tensors: {extra}"

    @staticmethod
    def constraint_ellipsis_mask(op):
        "ellipsis_mask must be 0"
        ellipsis = op.attrs["ellipsis_mask"]
        valid = ellipsis == 0
        return valid, f"Op has ellipsis mask as: {ellipsis}"

    @staticmethod
    def constraint_axis_masks(op):
        "new_axis_mask and shrink_axis_mask cannot both be set"
        new_axis = op.attrs["new_axis_mask"]
        shrink_axis = op.attrs["shrink_axis_mask"]
        valid = (new_axis == 0) or (shrink_axis == 0)
        return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"

    @staticmethod
    def constraint_slice_ranges(op):
        "Slice 'end' values must be greater than 'begin' values"
        ifm, begin, end, _ = op.inputs
        # Calculate offset begin/end
        offset_begin = get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
        offset_end = get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
        # Check "end - begin" doesn't result in any zero or negative elements
        valid = all((e - b) > 0 for b, e in zip(offset_begin, offset_end))
        return valid, f"Op has begin_values={begin.values} and end_values={end.values}"

    @staticmethod
    def constraint_matching_inputs_types(op):
        "Both Input data types must match"
        ifm_dtype = op.ifm.dtype
        ifm2_dtype = op.ifm2.dtype
        valid = ifm_dtype == ifm2_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"

    @staticmethod
    def constraint_matching_signed(op):
        "For IFM that are signed, OFM must also be signed"
        valid = True
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        if ifm_dtype.type & BaseType.Signed:
            valid = bool(ofm_dtype.type & BaseType.Signed)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_unsigned_valid(op):
        "For IFM that are unsigned, OFM must either be the same type or int32"
        valid = True
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        if ifm_dtype.type & BaseType.Unsigned:
            valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_input_8bit(op):
        "IFM must be int8 or uint8"
        ifm_dtype = op.ifm.dtype
        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
        return valid, f"Op has ifm_dtype={ifm_dtype}"

    @staticmethod
    def constraint_matching_either_shapes(op):
        "At least one Input's shape must match the OFM's shape"
        ifm_shape = op.ifm.shape
        ifm2_shape = op.ifm2.shape if op.ifm2 else None
        ofm_shape = op.ofm.shape
        valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
        return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_keep_dim_ifm_ofm(op):
        "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
        valid = True
        if op.attrs.get("keep_num_dims"):
            valid = len(op.ifm.shape) == len(op.ofm.shape)
        return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"

    @staticmethod
    def constraint_mean_input_dims(op):
        "Input tensor must be at least 2D"
        # Also rejects inputs above 4D
        dims = len(op.inputs[0].shape)
        return 2 <= dims <= 4, f"Input is {dims}D"

    @staticmethod
    def constraint_mean_axis(op):
        "Axis indices must correspond to height and width axes"
        dims = len(op.inputs[0].shape)
        # The axis tensor may be a scalar or a list of axes
        axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
        if dims in (2, 3):
            valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
        elif dims == 4:
            valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
        else:
            # Other ranks are rejected by constraint_mean_input_dims; fail safely
            # here instead of reading an unbound local.
            valid = False
        return valid, f"Axis is {axis}"

    @staticmethod
    def constraint_matching_in_out_quant(op):
        "Input and output quantisation must match."
        if not check_quantized_tens_scaling_equal(op.ifm, op.ofm):
            return False, "IFM and OFM quantisation parameters are not equal."
        return True, "IFM and OFM quantisation parameters matches."
610
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200611
def tflite_semantic_checker(nng):
    """Run the TFLite semantic checks over every op in the network graph.

    Each op's ``run_on_npu`` flag is overwritten with the result of
    ``TFLiteSemantic.is_operator_semantic_valid``. The same graph object is
    returned to allow chaining.
    """
    checker = TFLiteSemantic()
    is_valid = checker.is_operator_semantic_valid
    for subgraph in nng.subgraphs:
        for operation in subgraph.get_all_ops():
            operation.run_on_npu = is_valid(operation)
    return nng