blob: c8b373a3498d11b3d601610d786baa24e8dcd983 [file] [log] [blame]
Jonas Ohlsson45e653d2021-07-26 16:13:12 +02001# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# Description:
17# The TFLiteSemantic class which is a collection of TensorFlow lite model semantic checks.
18from collections import defaultdict
19
20import numpy as np
21
22from .data_type import BaseType
23from .data_type import DataType
24from .numeric_util import is_integer
25from .operation import get_slice_offsets
26from .operation import Op
27from .supported_operators_util import docstring_format_args
28from .supported_operators_util import list_formatter
29from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
30from .tflite_mapping import optype_to_builtintype
31
32
def _optype_formatter(op_list):
    """Format internal op types as a human-readable list of TFLite builtin names.

    Each op type is mapped to its external builtin name with
    ``optype_to_builtintype``. Op types that map to ``BUILTIN_OPERATOR_UNKNOWN``
    are dropped, and the remaining names are passed to ``list_formatter``.
    """
    # Convert internal op types to external names
    names = map(optype_to_builtintype, op_list)
    # Remove UNKNOWNs. Compare by value, not identity: `is not` only works if
    # the mapping happens to return the exact same sentinel object every time.
    known_names = (name for name in names if name != BUILTIN_OPERATOR_UNKNOWN)
    return list_formatter(known_names)
39
40
class TFLiteSemantic:
    """Collection of TensorFlow Lite semantic checks applied to graph operations.

    Every constraint is a callable that takes an operation and returns a
    ``(valid, extra)`` tuple. When a constraint fails,
    ``is_operator_semantic_valid`` prints the constraint's docstring as the
    user-facing description, so constraint docstrings are runtime output and
    must be edited with care.
    """

    # Categorised lists of operators
    convolution_ops = {Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D}
    depthwise_convolution_ops = {Op.DepthwiseConv2DBias}
    transpose_convolution_ops = {Op.Conv2DBackpropInput}
    convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = {Op.ReduceSum} | max_pooling_ops | avg_pooling_ops
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = {Op.Minimum, Op.Maximum}
    binary_elem_wise_shift_ops = {Op.SHL, Op.SHR}
    binary_elem_wise_add_mul_sub = {Op.Add, Op.Mul, Op.Sub}
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    # Op types allowed to take scalar (shape []) inputs; see constraint_tens_input_scalar.
    shapeless_input_ops = binary_elem_wise_main_ops | {Op.Split, Op.SplitV, Op.Mean}

    def __init__(self):
        """Build the ordered generic constraint list and the per-op-type constraint lists."""
        # Setup the generic constraints. Note: the order matters, because later
        # checks rely on earlier ones having passed (e.g. the quantization-scale
        # checks assume tensors have quantization parameters).
        self.generic_constraints = []
        self.generic_constraints.append(TFLiteSemantic.constraint_tens_no_dynamic)
        self.generic_constraints.append(TFLiteSemantic.constraint_tens_defined_shape)
        self.generic_constraints.append(TFLiteSemantic.constraint_tens_output_scalar)
        self.generic_constraints.append(TFLiteSemantic.constraint_tens_input_scalar)
        self.generic_constraints.append(TFLiteSemantic.constraint_tens_shape_size)

        self.generic_constraints.append(TFLiteSemantic.constraint_tens_quant_none_check)
        self.generic_constraints.append(TFLiteSemantic.constraint_tens_quant_scale)
        self.generic_constraints.append(TFLiteSemantic.constraint_quant_scale_inf)

        # Setup specific constraints. Note: the order matters
        self.specific_constraints = defaultdict(list)

        # Conv-like checks:
        for op_type in TFLiteSemantic.convolution_like_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_dilation_type)

        # Pooling checks:
        for op_type in TFLiteSemantic.pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
        # AVG pooling specific checks:
        for op_type in TFLiteSemantic.avg_pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_filter_type)
        # MAX pooling specific checks:
        for op_type in TFLiteSemantic.max_pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_filter_type)

        # Concat specific checks:
        for op_type in (Op.Concat, Op.ConcatTFLite):
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_exists)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_axis_valid)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_dimensionality)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_valid_dimensions)

        # Element-wise checks:
        for op_type in TFLiteSemantic.elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_either_shapes)
        # Unary specific checks:
        for op_type in TFLiteSemantic.unary_elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Min/Max specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_min_max_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Add/Mul/Sub specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_add_mul_sub:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_inputs_types)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_signed)
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_unsigned_valid)

        # Softmax specific checks:
        self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_shapes)
        self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_matching_in_out_types)
        self.specific_constraints[Op.Softmax].append(TFLiteSemantic.constraint_beta_value_range)

        # SplitV specific checks:
        self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred)

        # StridedSlice specific checks (input count first: later checks unpack 4 inputs):
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_stridedslice_input_count)
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_stridedslice_inputs_const)
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_ellipsis_mask)
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_axis_masks)
        self.specific_constraints[Op.StridedSlice].append(TFLiteSemantic.constraint_slice_ranges)

        # LeakyRelu specific checks:
        self.specific_constraints[Op.LeakyRelu].append(TFLiteSemantic.constraint_alpha_valid)

        # FullyConnected specific checks:
        self.specific_constraints[Op.FullyConnected].append(TFLiteSemantic.constraint_fc_output_2d)
        self.specific_constraints[Op.FullyConnected].append(TFLiteSemantic.constraint_keep_dim_ifm_ofm)

        # Pad specific checks:
        self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_input_count)
        self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_constant)

        # HardSwish specific checks:
        self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_input_8bit)
        self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Mean specific checks:
        self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_input_8bit)
        self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_input_dims)
        self.specific_constraints[Op.Mean].append(TFLiteSemantic.constraint_mean_axis)

    def is_operator_semantic_valid(self, op):
        """Return True if *op* passes every generic and op-type-specific constraint.

        Placeholder, SubgraphInput and Const ops are always accepted. On the
        first failing constraint, a warning naming the op and the constraint's
        docstring (plus any extra detail) is printed and False is returned.
        """
        ext_type = optype_to_builtintype(op.type)

        if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
            return True

        # .get() avoids inserting an empty entry into the defaultdict for every
        # op type that has no specific constraints.
        constraints = self.generic_constraints + self.specific_constraints.get(op.type, [])
        for constraint in constraints:
            valid, extra = constraint(op)
            if not valid:
                print(
                    f"Warning: unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead"
                )
                print(f" - {constraint.__doc__}")
                if extra:
                    print(f"   {extra}")
                return False

        return True

    @staticmethod
    def constraint_tens_no_dynamic(op):
        "Input(s) and Output tensors must not be dynamic"
        # A tensor with an empty shape and no constant values is treated as dynamic.
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (tens.values is None):
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has dynamic tensor(s): {extra}"

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output tensors must have a defined shape"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if not tens.has_fully_defined_shape():
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_scalar(op):
        "Output tensors cannot be scalar"
        ofm = op.ofm
        # op.ofm may be None (constraint_quant_scale_inf guards for it too);
        # an op without an output cannot have a scalar output.
        if ofm is None:
            return True, "Op has no output tensor"
        valid = ofm.shape != []
        return valid, f"Output Tensor '{ofm.name}' is scalar"

    @classmethod
    @docstring_format_args([_optype_formatter(shapeless_input_ops)])
    def constraint_tens_input_scalar(cls, op):
        "Scalar Input tensors are only valid for op type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has scalar input tensor(s): {extra}"

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output tensors must not be greater than 4D"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if len(tens.shape) > 4:
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Input(s), Output and Weight tensors must have quantization parameters"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is None:
                valid = False
                extra.append(tens.name)
        extra = ", ".join(extra)
        return valid, f"Op has tensors with missing quantization parameters: {extra}"

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Input(s), Output and Weight tensors with quantization scales must be finite"
        # Relies on constraint_tens_quant_none_check having run first, so that
        # tens.quantization is not None here.
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
                valid = False
                extra.append(f"Tensor '{tens.name}' has quantization scale: {tens.quantization.scale_f32}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_fc_output_2d(op):
        "The output tensor(s) must have 2D shape"
        valid = True
        extra = []
        for tens in op.outputs:
            if len(tens.shape) != 2:
                valid = False
                extra.append(f"Tensor '{tens.name}' is {len(tens.shape)}D")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_stride_type(op):
        "Stride values for both width and height must be integer types"
        w, h = op.get_kernel_stride()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_dilation_type(op):
        "Dilation factor values for both width and height must be integer types"
        w, h = op.get_kernel_dilation()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_quant_scale_inf(op):
        "Input and Output tensors must have quantization scales that fit within float32 precision"
        # Without a quantized OFM there is no OFM scale to validate or divide by.
        # Returning early also guarantees ofm_scale is bound before the
        # IFM/OFM ratio check below.
        if op.ofm is None or not op.ofm.is_quantized():
            return True, "Op's quantization is ok"
        ofm_scale = op.ofm.quantization.scale_f32
        # np.finfo(np.float32).tiny is the smallest positive normal float32.
        if ofm_scale < np.finfo(np.float32).tiny:
            return (
                False,
                f"The quantization scale of the output tensor is {ofm_scale}, "
                + f"minimum supported is: {np.finfo(np.float32).tiny}",
            )
        if op.ifm is not None and op.ifm.is_quantized():
            ifm_scale = op.ifm.quantization.scale_f32
            if np.isinf(ifm_scale / ofm_scale):
                return (
                    False,
                    f"IFM scale divided by OFM scale is infinite, ifm_scale={ifm_scale} ofm_scale={ofm_scale}",
                )
        return True, "Op's quantization is ok"

    @staticmethod
    def constraint_matching_in_out_types(op):
        "IFM and OFM data types must match"
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        valid = ifm_dtype == ofm_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_beta_value_range(op):
        "Beta value needs to be positive"
        # NOTE: beta == 0 is accepted; the check is "non-negative". Beta defaults
        # to 1.0 when the attribute is absent.
        beta = op.attrs.get("beta", 1.0)
        valid = beta >= 0
        return valid, f"Op has beta={beta}"

    @staticmethod
    def constraint_filter_type(op):
        "Kernel filter values for both width and height must be integer types"
        w = op.kernel.width
        h = op.kernel.height
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_matching_shapes(op):
        "IFM and OFM shapes must match"
        ifm_shape = op.ifm.shape
        ofm_shape = op.ofm.shape
        valid = ifm_shape == ofm_shape
        return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_splitv_inferred(op):
        "Only one size is allowed to be inferred"
        # A size of -1 means "infer this split size"; at most one is allowed.
        sizes = op.inputs[1].values
        valid = np.count_nonzero(sizes == -1) <= 1
        return valid, f"Op has multiple inferred sizes (-1): {sizes}"

    @staticmethod
    def constraint_axis_exists(op):
        "Axis attribute must exist"
        axis = op.attrs.get("axis")
        valid = axis is not None
        return valid, f"Op has axis={axis}"

    @staticmethod
    def constraint_axis_valid(op):
        "Axis attribute must be in the range [0, <ofm_dimensions>)"
        dims = len(op.ofm.shape)
        axis = op.attrs["axis"]
        # Negative axes count from the end, as in Python indexing.
        axis += dims if axis < 0 else 0
        valid = 0 <= axis < dims
        return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"

    @staticmethod
    def constraint_matching_dimensionality(op):
        "All Input dimensionalities must match OFM dimensionality"
        valid = True
        extra = []
        ofm_dim = len(op.ofm.shape)
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            dim = len(tens.shape)
            if dim != ofm_dim:
                valid = False
                extra.append(f"Tensor '{tens.name}' has dimension: {dim}")
        extra = ", ".join(extra)
        return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_valid_dimensions(op):
        "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
        # Relies on constraint_matching_dimensionality having passed, so every
        # input has as many dimensions as the OFM.
        valid = True
        extra = []
        ofm_shape = op.ofm.shape
        ofm_dim = len(ofm_shape)
        axis = op.attrs["axis"]
        axis += ofm_dim if axis < 0 else 0
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
                valid = False
                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
        extra = ", ".join(extra)
        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_stridedslice_input_count(op):
        "Exactly 4 Input tensors are required"
        inputs = len(op.inputs)
        valid = inputs == 4
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_input_count(op):
        "Number of input tensors must be exactly 2"
        inputs = len(op.inputs)
        valid = inputs == 2
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_constant(op):
        "The padding tensor must be constant"
        pad_tensor = op.inputs[1].values
        valid = pad_tensor is not None
        return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"

    @staticmethod
    def constraint_stridedslice_inputs_const(op):
        "Begin, End and Stride Input tensors must be constant"
        valid = True
        extra = []
        # Unpacking assumes exactly 4 inputs (constraint_stridedslice_input_count runs first).
        _, begin, end, strides = op.inputs
        if begin.values is None:
            valid = False
            extra.append(f"Begin tensor '{begin.name}'")
        if end.values is None:
            valid = False
            extra.append(f"End tensor '{end.name}'")
        if strides.values is None:
            valid = False
            extra.append(f"Stride tensor '{strides.name}'")
        extra = ", ".join(extra)
        return valid, f"Op has non-constant tensors: {extra}"

    @staticmethod
    def constraint_ellipsis_mask(op):
        "ellipsis_mask must be 0"
        ellipsis = op.attrs["ellipsis_mask"]
        valid = ellipsis == 0
        return valid, f"Op has ellipsis mask as: {ellipsis}"

    @staticmethod
    def constraint_axis_masks(op):
        "new_axis_mask and shrink_axis_mask cannot both be set"
        new_axis = op.attrs["new_axis_mask"]
        shrink_axis = op.attrs["shrink_axis_mask"]
        valid = (new_axis == 0) or (shrink_axis == 0)
        return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"

    @staticmethod
    def constraint_slice_ranges(op):
        "Slice 'end' values must be greater than 'begin' values"
        ifm, begin, end, _ = op.inputs
        # Calculate offset begin/end
        offset_begin = get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
        offset_end = get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
        # Check "end - begin" doesn't result in any zero or negative elements
        valid = all((e - b) > 0 for b, e in zip(offset_begin, offset_end))
        return valid, f"Op has begin_values={begin.values} and end_values={end.values}"

    @staticmethod
    def constraint_matching_inputs_types(op):
        "Both Input data types must match"
        ifm_dtype = op.ifm.dtype
        ifm2_dtype = op.ifm2.dtype
        valid = ifm_dtype == ifm2_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"

    @staticmethod
    def constraint_matching_signed(op):
        "For IFM that are signed, OFM must also be signed"
        valid = True
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        if ifm_dtype.type & BaseType.Signed:
            valid = bool(ofm_dtype.type & BaseType.Signed)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_unsigned_valid(op):
        "For IFM that are unsigned, OFM must either be the same type or int32"
        valid = True
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        if ifm_dtype.type & BaseType.Unsigned:
            valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_input_8bit(op):
        "IFM must be int8 or uint8"
        ifm_dtype = op.ifm.dtype
        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
        return valid, f"Op has ifm_dtype={ifm_dtype}"

    @staticmethod
    def constraint_matching_either_shapes(op):
        "At least one Input's shape must match the OFM's shape"
        ifm_shape = op.ifm.shape
        # Unary ops have no second input.
        ifm2_shape = op.ifm2.shape if op.ifm2 else None
        ofm_shape = op.ofm.shape
        valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
        return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_alpha_valid(op):
        "Alpha must not be negative"
        alpha = op.attrs["alpha"]
        valid = alpha >= 0
        return valid, f"Op has alpha={alpha}"

    @staticmethod
    def constraint_keep_dim_ifm_ofm(op):
        "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
        valid = True
        if op.attrs.get("keep_num_dims"):
            valid = len(op.ifm.shape) == len(op.ofm.shape)
        return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"

    @staticmethod
    def constraint_mean_input_dims(op):
        "Input tensor must be at least 2D"
        # The upper bound of 4 mirrors the generic constraint_tens_shape_size,
        # which runs before any op-specific constraint.
        dims = len(op.inputs[0].shape)
        return 2 <= dims <= 4, f"Input is {dims}D"

    @staticmethod
    def constraint_mean_axis(op):
        "Axis indices must correspond to height and width axes"
        dims = len(op.inputs[0].shape)
        # The axis input is either a scalar or a 1D list of axis indices.
        axis = int(op.inputs[1].values) if op.inputs[1].shape == [] else list(op.inputs[1].values)
        # Default to invalid so an unexpected rank yields a clean failure
        # instead of an UnboundLocalError.
        valid = False
        if dims == 2 or dims == 3:
            valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
        elif dims == 4:
            valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
        return valid, f"Axis is {axis}"
520
521
def tflite_semantic_checker(nng):
    """Flag every operation in *nng* with whether it passes the TFLite semantic checks.

    One ``TFLiteSemantic`` checker is created. Each op's ``run_on_npu``
    attribute is set to the result of ``is_operator_semantic_valid`` for that
    op. The graph is updated in place and also returned.
    """
    is_valid = TFLiteSemantic().is_operator_semantic_valid
    for subgraph in nng.subgraphs:
        for operation in subgraph.get_all_ops():
            operation.run_on_npu = is_valid(operation)
    return nng
527 return nng