blob: 3b7f248a9b952d10307f619495d72102d5b8a7dc [file] [log] [blame]
# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The TFLiteSemantic class which is a collection of TensorFlow Lite model semantic checks.
18from collections import defaultdict
19
20import numpy as np
21
22from .data_type import BaseType
23from .data_type import DataType
24from .numeric_util import is_integer
25from .operation import get_slice_offsets
26from .operation import Op
27from .supported_operators_util import docstring_format_args
28from .supported_operators_util import list_formatter
Jonas Ohlsson0957e3e2021-09-01 15:57:21 +020029from .tensor import check_quantized_tens_scaling_equal
Jonas Ohlsson45e653d2021-07-26 16:13:12 +020030from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
31from .tflite_mapping import optype_to_builtintype
32
33
def _optype_formatter(op_list):
    """Format a collection of internal op types as external TFLite builtin names.

    Each op type is converted with ``optype_to_builtintype``; types that map to
    ``BUILTIN_OPERATOR_UNKNOWN`` (no TFLite builtin equivalent) are dropped, and the
    remaining names are passed to ``list_formatter``.
    """
    # Equality rather than identity: `!=` is correct whatever type the UNKNOWN sentinel
    # is, whereas `is not` only works if optype_to_builtintype hands back that exact object.
    external_names = (
        name for name in map(optype_to_builtintype, op_list) if name != BUILTIN_OPERATOR_UNKNOWN
    )
    return list_formatter(external_names)
40
41
class TFLiteSemantic:
    """Collection of TensorFlow Lite semantic checks applied to each operator.

    Every ``constraint_*`` method takes an op and returns a ``(valid, extra)`` tuple,
    where ``extra`` is a detail message. When a check fails, the constraint's docstring
    is printed as its description (see ``is_operator_semantic_valid``), so constraint
    docstrings are user-facing text.
    """

    # Categorised lists of operators
    convolution_ops = {Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D}
    depthwise_convolution_ops = {Op.DepthwiseConv2DBias}
    transpose_convolution_ops = {Op.Conv2DBackpropInput}
    convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = {Op.ReduceSum} | max_pooling_ops | avg_pooling_ops
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = {Op.Minimum, Op.Maximum}
    binary_elem_wise_shift_ops = {Op.SHL, Op.SHR}
    binary_elem_wise_add_mul_sub = {Op.Add, Op.Mul, Op.Sub}
    binary_elem_wise_main_ops = (
        binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    )
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    # Op types allowed to take scalar (shape []) inputs, see constraint_tens_input_scalar
    shapeless_input_ops = binary_elem_wise_main_ops | {Op.Split, Op.SplitV, Op.Mean, Op.ExpandDims}
    # Op types that only rearrange dimensions; checked with constraint_matching_in_out_quant
    reshape_ops = {Op.Reshape, Op.QuantizedReshape, Op.Squeeze, Op.ExpandDims}

    def __init__(self):
        """Build the generic and per-op-type constraint lists.

        Constraints are evaluated in list order and evaluation stops at the first failure,
        so later constraints may rely on properties established by earlier ones (e.g. the
        quantization-scale checks run after the check that quantization parameters exist).
        """
        # Setup the generic constraints. Note: the order matters
        self.generic_constraints = [
            TFLiteSemantic.constraint_tens_no_dynamic,
            TFLiteSemantic.constraint_tens_defined_shape,
            TFLiteSemantic.constraint_tens_output_scalar,
            TFLiteSemantic.constraint_tens_input_scalar,
            TFLiteSemantic.constraint_tens_shape_size,
            TFLiteSemantic.constraint_tens_quant_none_check,
            TFLiteSemantic.constraint_tens_quant_scale,
            TFLiteSemantic.constraint_quant_scale_inf,
            TFLiteSemantic.constraint_none_const_tensors,
        ]

        # Setup specific constraints. Note: the order matters. An op type that belongs to
        # several categories accumulates the constraints of each, in the order below.
        self.specific_constraints = defaultdict(list)

        # Conv-like checks:
        for op_type in TFLiteSemantic.convolution_like_ops:
            self.specific_constraints[op_type].extend(
                [TFLiteSemantic.constraint_stride_type, TFLiteSemantic.constraint_dilation_type]
            )

        # Pooling checks:
        for op_type in TFLiteSemantic.pooling_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
        # AVG pooling specific checks:
        for op_type in TFLiteSemantic.avg_pooling_ops:
            self.specific_constraints[op_type].extend(
                [TFLiteSemantic.constraint_matching_in_out_types, TFLiteSemantic.constraint_filter_type]
            )
        # MAX pooling specific checks:
        for op_type in TFLiteSemantic.max_pooling_ops:
            self.specific_constraints[op_type].extend(
                [TFLiteSemantic.constraint_matching_in_out_types, TFLiteSemantic.constraint_filter_type]
            )

        # Concat specific checks:
        for op_type in (Op.Concat, Op.ConcatTFLite):
            self.specific_constraints[op_type].extend(
                [
                    TFLiteSemantic.constraint_axis_exists,
                    TFLiteSemantic.constraint_axis_valid,
                    TFLiteSemantic.constraint_matching_dimensionality,
                    TFLiteSemantic.constraint_valid_dimensions,
                ]
            )

        # Element-wise checks:
        for op_type in TFLiteSemantic.elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_either_shapes)
        # Unary specific checks:
        for op_type in TFLiteSemantic.unary_elem_wise_main_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Min/Max specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_min_max_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_types)
        # Binary Add/Mul/Sub specific checks:
        for op_type in TFLiteSemantic.binary_elem_wise_add_mul_sub:
            self.specific_constraints[op_type].extend(
                [
                    TFLiteSemantic.constraint_matching_inputs_types,
                    TFLiteSemantic.constraint_matching_signed,
                    TFLiteSemantic.constraint_unsigned_valid,
                ]
            )

        # Ops reshaping dimensions: Reshape, Squeeze and ExpandDims
        for op_type in TFLiteSemantic.reshape_ops:
            self.specific_constraints[op_type].append(TFLiteSemantic.constraint_matching_in_out_quant)

        # Softmax specific checks:
        self.specific_constraints[Op.Softmax].extend(
            [
                TFLiteSemantic.constraint_matching_shapes,
                TFLiteSemantic.constraint_matching_in_out_types,
                TFLiteSemantic.constraint_beta_value_range,
            ]
        )

        # SplitV specific checks:
        self.specific_constraints[Op.SplitV].append(TFLiteSemantic.constraint_splitv_inferred)

        # StridedSlice specific checks (input count first: later checks unpack 4 inputs):
        self.specific_constraints[Op.StridedSlice].extend(
            [
                TFLiteSemantic.constraint_stridedslice_input_count,
                TFLiteSemantic.constraint_stridedslice_inputs_const,
                TFLiteSemantic.constraint_ellipsis_mask,
                TFLiteSemantic.constraint_axis_masks,
                TFLiteSemantic.constraint_slice_ranges,
            ]
        )

        # LeakyRelu specific checks:
        self.specific_constraints[Op.LeakyRelu].append(TFLiteSemantic.constraint_alpha_valid)

        # FullyConnected specific checks:
        self.specific_constraints[Op.FullyConnected].extend(
            [TFLiteSemantic.constraint_fc_output_2d, TFLiteSemantic.constraint_keep_dim_ifm_ofm]
        )

        # Pad specific checks (input count first: constraint_pad_constant reads inputs[1]):
        self.specific_constraints[Op.Pad].extend(
            [TFLiteSemantic.constraint_pad_input_count, TFLiteSemantic.constraint_pad_constant]
        )

        # HardSwish specific checks:
        self.specific_constraints[Op.HardSwish].extend(
            [TFLiteSemantic.constraint_input_8bit, TFLiteSemantic.constraint_matching_in_out_types]
        )

        # Mean specific checks (input dims first: constraint_mean_axis branches on them):
        self.specific_constraints[Op.Mean].extend(
            [
                TFLiteSemantic.constraint_input_8bit,
                TFLiteSemantic.constraint_mean_input_dims,
                TFLiteSemantic.constraint_mean_axis,
            ]
        )

    def is_operator_semantic_valid(self, op):
        """Return True if ``op`` passes every applicable semantic constraint.

        Placeholder, SubgraphInput and Const ops are always accepted. Otherwise the generic
        constraints and then the op type's specific constraints are evaluated in order. On
        the first failure, a warning naming the op and the failed constraint (its docstring
        plus any detail message) is printed and False is returned.
        """
        if op.type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
            return True

        # .get() instead of indexing: indexing the defaultdict would insert an empty entry
        # for every op type that has no specific constraints.
        specific = self.specific_constraints.get(op.type, [])
        for constraint in self.generic_constraints + specific:
            valid, extra = constraint(op)
            if not valid:
                ext_type = optype_to_builtintype(op.type)
                print(
                    f"Warning: Unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead"
                )
                print(f" - {constraint.__doc__}")
                if extra:
                    print(f" {extra}")
                return False

        return True

    @staticmethod
    def constraint_none_const_tensors(op):
        "Constant tensors should not have NoneType-values"
        # Report every offending tensor, not just the last one found.
        offending = [
            str(tens.name)
            for tens in filter(None, op.inputs)
            if len(tens.ops) > 0 and tens.ops[0].type == Op.Const and tens.values is None
        ]
        valid = not offending
        extra = ", ".join(offending)
        return valid, f"Unexpected None value for constant tensor: {extra}"

    @staticmethod
    def constraint_tens_no_dynamic(op):
        "Input(s) and Output tensors must not be dynamic"
        # A tensor with an empty shape and no values is treated as dynamic.
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        dynamic = [tens.name for tens in tensors if (tens.shape == []) and (tens.values is None)]
        valid = not dynamic
        extra = ", ".join(dynamic)
        return valid, f"Op has dynamic tensor(s): {extra}"

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output tensors must have a defined shape"
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        extra = [
            f"Tensor '{tens.name}' has shape: {tens.shape}" for tens in tensors if not tens.has_fully_defined_shape()
        ]
        valid = not extra
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_scalar(op):
        "Output tensors cannot be scalar"
        ofm = op.ofm
        valid = ofm.shape != []
        return valid, f"Output Tensor '{ofm.name}' is scalar"

    @classmethod
    @docstring_format_args([_optype_formatter(shapeless_input_ops)])
    def constraint_tens_input_scalar(cls, op):
        "Scalar Input tensors are only valid for op type: {}"
        tensors = [tens for tens in op.inputs if tens]
        scalars = [tens.name for tens in tensors if (tens.shape == []) and (op.type not in cls.shapeless_input_ops)]
        valid = not scalars
        extra = ", ".join(scalars)
        return valid, f"Op has scalar input tensor(s): {extra}"

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output tensors must not be greater than 4D"
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        extra = [f"Tensor '{tens.name}' has shape: {tens.shape}" for tens in tensors if len(tens.shape) > 4]
        valid = not extra
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Input(s), Output and Weight tensors must have quantization parameters"
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        missing = [tens.name for tens in tensors if tens.quantization is None]
        valid = not missing
        extra = ", ".join(missing)
        return valid, f"Op has tensors with missing quantization parameters: {extra}"

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Input(s), Output and Weight tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            quant = tens.quantization
            # Guard against missing quantization even though constraint_tens_quant_none_check
            # normally runs first, so this check does not crash if reordered.
            if quant is not None and quant.scale_f32 is not None and np.isinf(quant.scale_f32).any():
                valid = False
                extra.append(f"Tensor '{tens.name}' has quantization scale: {quant.scale_f32}")
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_fc_output_2d(op):
        "The output tensor(s) must have 2D shape"
        extra = [f"Tensor '{tens.name}' is {len(tens.shape)}D" for tens in op.outputs if len(tens.shape) != 2]
        valid = not extra
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_stride_type(op):
        "Stride values for both width and height must be integer types"
        w, h = op.get_kernel_stride()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_dilation_type(op):
        "Dilation factor values for both width and height must be integer types"
        w, h = op.get_kernel_dilation()
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has dilation factor WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_quant_scale_inf(op):
        "Input and Output tensors must have quantization scales that fit within float32 precision"
        # Both checks need the OFM scale, so they only apply when the OFM is quantized. The
        # IFM/OFM ratio check is nested here so it never reads an unbound ofm_scale, which
        # would raise UnboundLocalError when only the IFM is quantized.
        if op.ofm is not None and op.ofm.is_quantized():
            ofm_scale = op.ofm.quantization.scale_f32
            if np.any(ofm_scale < np.finfo(np.float32).tiny):
                return (
                    False,
                    f"The quantization scale of the output tensor is {ofm_scale}, "
                    + f"minimum supported is: {np.finfo(np.float32).tiny}",
                )
            if op.ifm is not None and op.ifm.is_quantized():
                ifm_scale = op.ifm.quantization.scale_f32
                # ofm_scale >= float32 tiny here, so this is never a division by zero
                if np.any(np.isinf(ifm_scale / ofm_scale)):
                    return (
                        False,
                        f"IFM scale divided by OFM scale is infinite, ifm_scale={ifm_scale} ofm_scale={ofm_scale}",
                    )
        return True, "Op's quantization is ok"

    @staticmethod
    def constraint_matching_in_out_types(op):
        "IFM and OFM data types must match"
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        valid = ifm_dtype == ofm_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_beta_value_range(op):
        "Beta value needs to be positive"
        # A missing beta attribute defaults to 1.0; note that beta == 0 is accepted.
        beta = op.attrs.get("beta", 1.0)
        valid = beta >= 0
        return valid, f"Op has beta={beta}"

    @staticmethod
    def constraint_filter_type(op):
        "Kernel filter values for both width and height must be integer types"
        w = op.kernel.width
        h = op.kernel.height
        valid = is_integer(w) and is_integer(h)
        return valid, f"Op has kernel filter WxH as: {repr(w)}x{repr(h)}"

    @staticmethod
    def constraint_matching_shapes(op):
        "IFM and OFM shapes must match"
        ifm_shape = op.ifm.shape
        ofm_shape = op.ofm.shape
        valid = ifm_shape == ofm_shape
        return valid, f"Op has ifm_shape={ifm_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_splitv_inferred(op):
        "Only one size is allowed to be inferred"
        # A size of -1 means "infer this size"; at most one may be inferred.
        sizes = op.inputs[1].values
        valid = np.count_nonzero(sizes == -1) <= 1
        return valid, f"Op has multiple inferred sizes (-1): {sizes}"

    @staticmethod
    def constraint_axis_exists(op):
        "Axis attribute must exist"
        axis = op.attrs.get("axis")
        valid = axis is not None
        return valid, f"Op has axis={axis}"

    @staticmethod
    def constraint_axis_valid(op):
        "Axis attribute must be in the range [0, <ofm_dimensions>)"
        dims = len(op.ofm.shape)
        axis = op.attrs["axis"]
        # Negative axes count from the end
        axis += dims if axis < 0 else 0
        valid = 0 <= axis < dims
        return valid, f"Op has ofm_dimensions={dims} and axis attribute is: {axis}"

    @staticmethod
    def constraint_matching_dimensionality(op):
        "All Input dimensionalities must match OFM dimensionality"
        ofm_dim = len(op.ofm.shape)
        tensors = [tens for tens in op.inputs if tens]
        mismatches = [
            f"Tensor '{tens.name}' has dimension: {len(tens.shape)}" for tens in tensors if len(tens.shape) != ofm_dim
        ]
        valid = not mismatches
        extra = ", ".join(mismatches)
        return valid, f"Op has ofm_dimension={ofm_dim} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_valid_dimensions(op):
        "All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"
        ofm_shape = op.ofm.shape
        ofm_dim = len(ofm_shape)
        axis = op.attrs["axis"]
        # Negative axes count from the end
        axis += ofm_dim if axis < 0 else 0
        tensors = [tens for tens in op.inputs if tens]
        # Relies on constraint_matching_dimensionality having run first, so every input
        # has ofm_dim dimensions and the indexing below is in range.
        mismatches = [
            f"Tensor '{tens.name}' has shape: {tens.shape}"
            for tens in tensors
            if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis)
        ]
        valid = not mismatches
        extra = ", ".join(mismatches)
        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"

    @staticmethod
    def constraint_stridedslice_input_count(op):
        "Exactly 4 Input tensors are required"
        inputs = len(op.inputs)
        valid = inputs == 4
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_input_count(op):
        "Number of input tensors must be exactly 2"
        inputs = len(op.inputs)
        valid = inputs == 2
        return valid, f"Op has {inputs} inputs"

    @staticmethod
    def constraint_pad_constant(op):
        "The padding tensor must be constant"
        pad_tens = op.inputs[1]
        valid = pad_tens.values is not None
        # Report the tensor name: on failure its values are always None, which is uninformative.
        return valid, f"Op has non-constant padding tensor: '{pad_tens.name}'"

    @staticmethod
    def constraint_stridedslice_inputs_const(op):
        "Begin, End and Stride Input tensors must be constant"
        _, begin, end, strides = op.inputs
        extra = [
            f"{label} tensor '{tens.name}'"
            for label, tens in (("Begin", begin), ("End", end), ("Stride", strides))
            if tens.values is None
        ]
        valid = not extra
        extra = ", ".join(extra)
        return valid, f"Op has non-constant tensors: {extra}"

    @staticmethod
    def constraint_ellipsis_mask(op):
        "ellipsis_mask must be 0"
        ellipsis = op.attrs["ellipsis_mask"]
        valid = ellipsis == 0
        return valid, f"Op has ellipsis mask as: {ellipsis}"

    @staticmethod
    def constraint_axis_masks(op):
        "new_axis_mask and shrink_axis_mask cannot both be set"
        new_axis = op.attrs["new_axis_mask"]
        shrink_axis = op.attrs["shrink_axis_mask"]
        valid = (new_axis == 0) or (shrink_axis == 0)
        return valid, f"Op has new_axis_mask={new_axis} and shrink_axis_mask={shrink_axis}"

    @staticmethod
    def constraint_slice_ranges(op):
        "Slice 'end' values must be greater than 'begin' values"
        ifm, begin, end, _ = op.inputs
        # Calculate offset begin/end
        offset_begin = get_slice_offsets(ifm.shape, begin, op.attrs["begin_mask"], is_begin=True)
        offset_end = get_slice_offsets(ifm.shape, end, op.attrs["end_mask"], is_begin=False)
        # Check "end - begin" doesn't result in any zero or negative elements
        valid = all((e - b) > 0 for b, e in zip(offset_begin, offset_end))
        return valid, f"Op has begin_values={begin.values} and end_values={end.values}"

    @staticmethod
    def constraint_matching_inputs_types(op):
        "Both Input data types must match"
        ifm_dtype = op.ifm.dtype
        ifm2_dtype = op.ifm2.dtype
        valid = ifm_dtype == ifm2_dtype
        return valid, f"Op has ifm_dtype={ifm_dtype} and ifm2_dtype={ifm2_dtype}"

    @staticmethod
    def constraint_matching_signed(op):
        "For IFM that are signed, OFM must also be signed"
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        valid = True
        if ifm_dtype.type & BaseType.Signed:
            valid = bool(ofm_dtype.type & BaseType.Signed)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_unsigned_valid(op):
        "For IFM that are unsigned, OFM must either be the same type or int32"
        ifm_dtype = op.ifm.dtype
        ofm_dtype = op.ofm.dtype
        valid = True
        if ifm_dtype.type & BaseType.Unsigned:
            valid = (ifm_dtype == ofm_dtype) or (ofm_dtype == DataType.int32)
        return valid, f"Op has ifm_dtype={ifm_dtype} and ofm_dtype={ofm_dtype}"

    @staticmethod
    def constraint_input_8bit(op):
        "IFM must be int8 or uint8"
        ifm_dtype = op.ifm.dtype
        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
        return valid, f"Op has ifm_dtype={ifm_dtype}"

    @staticmethod
    def constraint_matching_either_shapes(op):
        "At least one Input's shape must match the OFM's shape"
        ifm_shape = op.ifm.shape
        # Unary ops have no second input
        ifm2_shape = op.ifm2.shape if op.ifm2 else None
        ofm_shape = op.ofm.shape
        valid = (ifm_shape == ofm_shape) or (ifm2_shape == ofm_shape)
        return valid, f"Op has ifm_shape={ifm_shape}, ifm2_shape={ifm2_shape} and ofm_shape={ofm_shape}"

    @staticmethod
    def constraint_alpha_valid(op):
        "Alpha must not be negative"
        alpha = op.attrs["alpha"]
        valid = alpha >= 0
        return valid, f"Op has alpha={alpha}"

    @staticmethod
    def constraint_keep_dim_ifm_ofm(op):
        "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
        valid = True
        if op.attrs.get("keep_num_dims"):
            valid = len(op.ifm.shape) == len(op.ofm.shape)
        return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"

    @staticmethod
    def constraint_mean_input_dims(op):
        "Input tensor must be at least 2D"
        dims = len(op.inputs[0].shape)
        return 2 <= dims <= 4, f"Input is {dims}D"

    @staticmethod
    def constraint_mean_axis(op):
        "Axis indices must correspond to height and width axes"
        dims = len(op.inputs[0].shape)
        axis_tens = op.inputs[1]
        # A scalar axis tensor becomes an int; otherwise the axis values become a list
        axis = int(axis_tens.values) if axis_tens.shape == [] else list(axis_tens.values)
        # Default to invalid so any other rank is rejected instead of leaving `valid` unbound.
        valid = False
        if dims in (2, 3):
            valid = axis in (0, 1, [0], [1], [0, 1], [1, 0])
        elif dims == 4:
            valid = axis in (1, 2, [1], [2], [1, 2], [2, 1])
        return valid, f"Axis is {axis}"

    @staticmethod
    def constraint_matching_in_out_quant(op):
        "Input and output quantisation must match."
        if not check_quantized_tens_scaling_equal(op.ifm, op.ofm):
            return False, "IFM and OFM quantisation parameters are not equal."
        return True, "IFM and OFM quantisation parameters matches."
545
Jonas Ohlsson45e653d2021-07-26 16:13:12 +0200546
547def tflite_semantic_checker(nng):
548 semantic_checker = TFLiteSemantic()
549 for sg in nng.subgraphs:
550 for op in sg.get_all_ops():
551 op.run_on_npu = semantic_checker.is_operator_semantic_valid(op)
552 return nng