# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
from collections import defaultdict

import numpy as np

from .data_type import BaseType
from .data_type import DataType
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op


# Custom decorator function to allow formatting docstrings containing "{}"
def docstring_format_args(args):
    def docstring(func):
        func.__doc__ = func.__doc__.format(*args)
        return func

    return docstring


def warn_cpu(op, msg):
    print("Warning: {} {}, placing on CPU".format(op.type, msg))


class SupportedOperators:
    # Categorised lists of supported operators
    npu_pre_ops = set((Op.SplitSliceRead,))
    convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
    depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
    transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
    convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
    resizing_ops = set((Op.ResizeBilinear,))
    fc_vector_products = set((Op.QuantizedMatMul, Op.MatMul, Op.FullyConnected,))
    mac_main_ops = (
        # RNN/LSTM/GRU
        set((Op.BlockLSTM,))
        # conv/depthwiseconv/transposeconv
        | convolution_like_ops
        # pooling
        | pooling_ops
        # resizing/upscaling
        | resizing_ops
        # FC layers
        | fc_vector_products
    )
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
    binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
    binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    supported_int32_tensor_ops = (
        set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    )
    activation_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Sigmoid, Op.Tanh, Op.Softmax,))
    npu_post_ops = (
        # activation functions
        activation_ops
        # concatenation write direction
        | set((Op.ConcatSliceWrite,))
        # Quantization
        | set((Op.Quantize,))
    )
    split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
    concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
    memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,)) | concat_ops | split_ops
    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
    supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
    supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
    # Supported data types
    supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
    supported_bias_dtypes = set((DataType.int32, DataType.int64))
    # Defined ranges for allowed values:
    tens_dim_range = (1, 65535)
    stride_range = (1, 3)
    dilation_range = (1, 2)
    dilated_height_range = (1, 64)
    dilated_product_range = (1, 64 * 64)
    weights_limit = 127 * 65536

    def __init__(self):
        # Setup supported operator restriction checkers
        self.supported_operator_restrictions = {}
        self.supported_operator_restrictions.update(
            {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
        )
        # Setup the generic constraints. Note: the order matters
        self.generic_constraints = []
        self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
        self.generic_constraints.append(SupportedOperators.constraint_tens_output_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_input_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
        self.generic_constraints.append(SupportedOperators.constraint_tens_int32_ops)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
        self.generic_constraints.append(SupportedOperators.constraint_faf)
        # Setup specific constraints. The key in the dictionary must be a tuple of op types the constraints apply to
        self.specific_constraints = defaultdict(list)
        # Conv-like ops have the same checks applied to them:
        conv_like_ops = tuple(SupportedOperators.convolution_like_ops)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_type)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_range)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_type)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_range)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_height_range)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_product_range)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_type)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_nonconst)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_limit)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_type)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_40bit)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_batch_size)

    def get_constraints_list(self, op_type):
        constraint_list = list(self.generic_constraints)
        for ops in self.specific_constraints:
            if op_type in ops:
                constraint_list.extend(self.specific_constraints[ops])
        return constraint_list

    def is_operator_supported(self, op):
        if op.type not in SupportedOperators.supported_operators:
            if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
                print("Info: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
            return False

        for constraint in self.get_constraints_list(op.type):
            valid, extra = constraint(op)
            if not valid:
                print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
                print(" - {}".format(constraint.__doc__))
                if extra:
                    print("   {}".format(extra))
                return False

        if op.type in self.supported_operator_restrictions:
            return self.supported_operator_restrictions[op.type](op)
        return True
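
    # Each constraint_* method below takes an op and returns a (valid, extra) pair. When a
    # constraint fails, is_operator_supported() prints the constraint's docstring as the
    # reason, so the docstrings double as user-facing messages; where used, the
    # docstring_format_args decorator fills their "{}" placeholders with the class-level
    # limits defined above.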

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output Tensors must have a defined shape"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if not tens.has_fully_defined_shape():
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_shapeless(op):
        "Scalar or Broadcasting Tensors are only valid for Input Tensors"
        valid = True
        extra = []
        for tens in op.outputs:
            if tens.shape == []:
                valid = False
                extra.append("Output Tensor '{}' is shapeless".format(tens.name))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([shapeless_input_ops])
    def constraint_tens_input_shapeless(cls, op):
        "Scalar or Broadcasting Input Tensors are only valid for op type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                valid = False
                extra.append(tens.name)
        extra = "Op has shapeless input tensor(s): {}".format(", ".join(extra))
        return valid, extra

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output Tensors must not be greater than 4D"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if len(tens.shape) > 4:
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_op_dtypes])
    def constraint_tens_dtype(cls, op):
        "Input(s), Output and Weight Tensors must be of type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if tens.dtype not in cls.supported_op_dtypes:
                valid = False
                extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_int32_tensor_ops])
    def constraint_tens_int32_ops(cls, op):
        "Tensors which are int32 are only valid when op type is: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
                valid = False
                extra.append(tens.name)
        extra = "Op has int32 tensor(s): {}".format(", ".join(extra))
        return valid, extra

    @classmethod
    @docstring_format_args(tens_dim_range)
    def constraint_tens_dimension(cls, op):
        "Tensor dimensions must be in the range [{}, {}]"
        tens_min, tens_max = cls.tens_dim_range
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if not all(tens_min <= dim <= tens_max for dim in tens.shape):
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Tensors must have quantization parameters"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is None:
                valid = False
                extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
                valid = False
                extra.append("Tensor '{}' has quantization scale: {}".format(tens.name, tens.quantization.scale_f32))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_fused_activations])
    def constraint_faf(cls, op):
        "The fused activation function (if present) must be one of type: {}"
        faf = op.activation
        valid = (faf is None) or (faf in cls.supported_fused_activations)
        extra = "Op has its fused activation function as: {}".format(faf)
        return valid, extra

    @staticmethod
    def constraint_stride_type(op):
        "Stride values for both width and height must be integer types"
        w = op.attrs["stride_w"]
        h = op.attrs["stride_h"]
        valid = is_integer(w) and is_integer(h)
        extra = "Op has stride WxH as: {}x{}".format(repr(w), repr(h))
        return valid, extra

    @classmethod
    @docstring_format_args(stride_range)
    def constraint_stride_range(cls, op):
        "Stride values for both width and height must be in the range [{}, {}]"
        w = op.attrs["stride_w"]
        h = op.attrs["stride_h"]
        stride_min, stride_max = cls.stride_range
        valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max)
        extra = "Op has stride WxH as: {}x{}".format(w, h)
        return valid, extra

    @staticmethod
    def constraint_dilation_type(op):
        "Dilation factor values for both width and height must be integer types"
        w = op.attrs.get("dilation_w_factor", 1)
        h = op.attrs.get("dilation_h_factor", 1)
        valid = is_integer(w) and is_integer(h)
        extra = "Op has dilation factor WxH as: {}x{}".format(repr(w), repr(h))
        return valid, extra

    @classmethod
    @docstring_format_args(dilation_range)
    def constraint_dilation_range(cls, op):
        "Dilation factor values for both width and height must be in the range [{}, {}]"
        w = op.attrs.get("dilation_w_factor", 1)
        h = op.attrs.get("dilation_h_factor", 1)
        dilation_min, dilation_max = cls.dilation_range
        valid = (dilation_min <= w <= dilation_max) and (dilation_min <= h <= dilation_max)
        extra = "Op has dilation factor WxH as: {}x{}".format(w, h)
        return valid, extra

    @classmethod
    @docstring_format_args(dilated_height_range)
    def constraint_dilated_height_range(cls, op):
        "Dilated kernel height must be in the range [{}, {}]"
        h = (op.weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
        dilated_height_min, dilated_height_max = cls.dilated_height_range
        valid = dilated_height_min <= h <= dilated_height_max
        extra = "Op has dilated kernel height as: {}".format(h)
        return valid, extra

    @classmethod
    @docstring_format_args(dilated_product_range)
    def constraint_dilated_product_range(cls, op):
        "Product of dilated kernel width and height must be in the range [{}, {}]"
        weights = op.weights
        w = (weights.shape[1] - 1) * op.attrs.get("dilation_w_factor", 1) + 1
        h = (weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
        product = w * h
        dilated_product_min, dilated_product_max = cls.dilated_product_range
        valid = dilated_product_min <= product <= dilated_product_max
        extra = "Op has product of dilated kernel width and height as: {}".format(product)
        return valid, extra

    @staticmethod
    def constraint_weights_type(op):
        "Weight Tensor must be 8-bit"
        weights = op.weights
        valid = weights.element_size() == 1
        extra = "Tensor '{}' is {}-bit".format(weights.name, int(weights.element_size() * 8))
        return valid, extra

    @staticmethod
    def constraint_weights_nonconst(op):
        "Weight tensor cannot be non-constant"
        weights = op.weights
        valid = weights.values is not None
        extra = "Tensor '{}' has non-constant values".format(weights.name)
        return valid, extra

    @classmethod
    @docstring_format_args([weights_limit])
    def constraint_weights_limit(cls, op):
        "The sum of the weights cannot exceed {}"
        weights = op.weights
        values = weights.quant_values.astype(np.int64) - weights.quantization.zero_point
        limit = np.amax(np.sum(np.absolute(values), axis=(0, 1, 2)))
        valid = limit <= cls.weights_limit
        extra = "Tensor '{}' has the sum of weights: {}".format(weights.name, limit)
        return valid, extra

    @classmethod
    @docstring_format_args([supported_bias_dtypes])
    def constraint_bias_type(cls, op):
        "Optional Bias Tensor must be of type: {}"
        valid = True
        extra = ""
        bias = op.bias
        if bias:
            valid = bias.dtype in cls.supported_bias_dtypes
            extra = "Tensor '{}' has data type: {}".format(bias.name, bias.dtype)
        return valid, extra

    @staticmethod
    def constraint_bias_40bit(op):
        "Optional Bias Tensor values must fit within 40-bits"
        valid = True
        extra = ""
        bias = op.bias
        if bias and bias.dtype == DataType.int64:
            valid = all(len(bin(quant_value)[2:]) <= 40 for quant_value in bias.quant_values)
            extra = "Tensor '{}' has values larger than 40-bits".format(bias.name)
        return valid, extra

    @staticmethod
    def constraint_batch_size(op):
        "IFM Tensor batch size must be 1"
        ifm = op.ifm
        valid = ifm.shape[0] == 1
        extra = "Tensor '{}' has batch size: {}".format(ifm.name, ifm.shape[0])
        return valid, extra
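
    # The check_* methods below are the older, per-operator restriction checkers that are
    # dispatched through self.supported_operator_restrictions; unlike the constraint_*
    # methods above, they print their own warnings and simply return a bool.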

    @classmethod
    def check_depthwise_convolution_restrictions(cls, op):
        # check depth
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if op.attrs["depth_multiplier"] > 1 and not (
            (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
        ):
            print(
                "Warning: for depth multipliers > 1,",
                "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
                "Placing on CPU",
            )
            return False
        return True

    @classmethod
    def check_transpose_convolution_restrictions(cls, op):
        # check stride
        stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
        if stride_h != 2 or stride_w != 2:
            print("Warning: stride must be equal to 2, placing on CPU")
            return False

        # check output dimensions
        ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
        ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
        ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
        if op.attrs["padding"] == b"SAME":
            if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
                print(
                    "Warning: for",
                    op.type,
                    "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
                    "Placing on CPU",
                )
                return False
        elif op.attrs["padding"] == b"VALID":
            kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
            if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
                ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
            ):
                print(
                    "Warning: for",
                    op.type,
                    "using VALID padding, output dimensions must equal input dimensions multiplied by stride,",
                    "minus difference between kernel size and stride. Placing on CPU",
                )
                return False
        return True

    @classmethod
    def check_pooling_restrictions(cls, op):
        # check stride
        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
        if not is_integer(stride_w) or not is_integer(stride_h):
            print("Warning:", op.type, "has non-integer stride, placing on CPU")
            return False
        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
            print(
                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
                    op.type, stride_w, stride_h
                )
            )
            return False

        # check data type
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if ifm_tensor.dtype != ofm_tensor.dtype:
            if op.type != Op.ReduceSum:
                print("Warning: input data type doesn't match output data type, placing on CPU")
                return False
            # TODO: else check ReduceSum restrictions.

        # check batch size
        if ifm_tensor.shape[0] != 1:
            print("Warning: input batch size must be 1, placing on CPU")
            return False

        # check kernel size
        kernel_w, kernel_h = op.attrs["filter_width"], op.attrs["filter_height"]
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"SAME":
            if not 1 <= kernel_w <= 8 or not 1 <= kernel_h <= 8:
                print(
                    "Warning:",
                    op.type,
                    "has kernel size ({}, {}), only kernel sizes in range [1, 8] are allowed. Placing on CPU".format(
                        kernel_w, kernel_h
                    ),
                )
                return False
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID" or op.type in cls.max_pooling_ops:
            if not 1 <= kernel_w * kernel_h <= 256 * 256:
                print(
                    "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format(
                        256 * 256
                    ),
                    "placing on CPU",
                )
                return False
            if not 1 <= kernel_h <= 256:
                print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU")
                return False

        return True

    @classmethod
    def check_resize_restrictions(cls, op):
        # check unsupported upscaling factor
        if op.type == Op.ResizeBilinear:
            if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
                return True
            if op.inputs[0].shape == op.outputs[0].shape:
                return True
            upscaled_shape = np.array(op.inputs[0].shape[1:3])
            out_shape = np.array(op.outputs[0].shape[1:3])
            while (upscaled_shape < out_shape).all():
                upscaled_shape *= 2
                if op.attrs["align_corners"]:
                    upscaled_shape -= 1
                if np.array_equal(out_shape, upscaled_shape):
                    return True
            return False

    @classmethod
    def check_vector_product_restrictions(cls, op):
        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type))
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-const weights, placing on CPU")
            return False

        return True

    @classmethod
    def check_element_wise_restrictions(cls, op):
        # check data type
        ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        # input and output datatype must match for these operators
        if (
            op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
            and ifm_tensor.dtype != ofm_tensor.dtype
        ):
            print("Warning:", op.type, "must have same input and output datatype, placing on CPU")
            return False
        if op.type in cls.binary_elem_wise_add_mul_sub:
            # both inputs must have same type
            if ifm_tensor.dtype != ifm2_tensor.dtype:
                print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU")
                return False
            # signed input check
            if ifm_tensor.dtype.type & BaseType.Signed:
                # output must be signed
                if ofm_tensor.dtype.type & BaseType.Unsigned:
                    print("Warning: only signed output types supported for {}, placing on CPU".format(op.type))
                    return False
                # and 8, 16 or 32-bit
                bit_lengths = {8, 16, 32}
                if ofm_tensor.element_size() * 8 not in bit_lengths:
                    print(
                        "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
                    )
                    return False
            # unsigned input check, output must be same type or int32
            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
                ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
            ):
                print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU")
                return False
        elif op.type in cls.binary_elem_wise_shift_ops:
            if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
                return False
            if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "output datatype is not int32, placing on CPU")
                return False

        # check batch size
        if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
            print(
                "Warning:",
                op.type,
                "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
            )
            return False
        if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
            if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
                print(
                    "Warning:",
                    op.type,
                    "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
                )
                return False

        # negative alpha values are not supported
        if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
            print("Warning:", op.type, "has negative alpha, placing on CPU")
            return False

        # check if ifm or ifm2 has ofm shape
        if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
            print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
            return False

        if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
            return False

        return True

    @classmethod
    def check_memory_only_restrictions(cls, op):
        if op.type == Op.StridedSlice:
            if len(op.inputs) != 4:
                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                return False
            input_tens, begin_tens, end_tens, strides_tens = op.inputs
            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
                return False
            if not (
                len(input_tens.shape)
                == len(op.outputs[0].shape)
                == len(begin_tens.values)
                == len(end_tens.values)
                == len(strides_tens.values)
            ):
                warn_cpu(op, "has input tensors with shapes that are not supported")
                return False
            # check stride size
            if any(stride != 1 for stride in strides_tens.values):
                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                return False
            # check ellipsis_mask
            if op.attrs["ellipsis_mask"] != 0:
                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                return False
            # check if both new_axis_mask and shrink_axis_mask have bit set
            if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
                return False
            # Calculate offset start/end
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
            # check "end - begin" doesn't result in any zero or negative elements
            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
                warn_cpu(
                    op,
                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
                        begin_tens.values, end_tens.values
                    ),
                )
                return False
        if op.type == Op.SplitV:
            # check that maximum one size is set to -1, indicating that size should be inferred
            sizes = op.inputs[1].values
            num_to_be_inferred = 0
            for size in sizes:
                if size == -1:
                    num_to_be_inferred += 1

            if num_to_be_inferred > 1:
                print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
                return False
        if op.type in set((Op.Concat, Op.ConcatTFLite,)):
            axis = op.attrs.get("axis", None)
            if axis is None:
                print("Warning:", op.type, "invalid or missing axis, placing on CPU")
                return False
            if axis < 0:
                axis += len(op.inputs[0].shape)
            if not 0 <= axis < len(op.inputs[0].shape):
                print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
                return False
            ofm = op.outputs[0]
            ofm_dims = len(ofm.shape)
            for ifm in op.inputs:
                if len(ifm.shape) != ofm_dims:
                    return False
                for i in range(ofm_dims):
                    if i != axis and ifm.shape[i] != ofm.shape[i]:
                        print(
                            "Warning:",
                            op.type,
                            "invalid ifm:",
                            ifm.name,
                            ifm.shape,
                            "mismatch in dimension",
                            i,
                            ", placing on CPU",
                        )
                        return False

        return True

    @classmethod
    def check_quantization_restrictions_binary_elem_wise(cls, op):
        # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
        assert len(op.inputs) >= 2 and len(op.outputs) == 1

        if (
            op.inputs[0].quantization is None
            or not op.inputs[0].is_scaling_equal(op.inputs[1])
            or not op.inputs[0].is_scaling_equal(op.outputs[0])
        ):
            print(
                "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator"
            )
            return False

        return True

    @classmethod
    def check_activation_ops(cls, op):
        if op.type == Op.Softmax:
            ifm_tensor = op.inputs[0]
            ofm_tensor = op.outputs[0]

            # check data type
            if ifm_tensor.dtype != ofm_tensor.dtype:
                print("Warning:", op.type, "input type differs from output type, placing on CPU")
                return False

            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
                print(
                    "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
                )
                return False

            # check shape
            if ifm_tensor.shape != ofm_tensor.shape:
                print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
                return False

        return True

    @classmethod
    def check_bias_restrictions(cls, bias_tensor):
        # check data type
        if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
            print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
            return False

        # check if values fits in 40-bit
        if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
            for quant_value in bias_tensor.quant_values:
                if not (-(1 << 39) <= quant_value < (1 << 39)):
                    print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
                    return False

        return True
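

# Example usage (illustrative sketch only; how operators are gathered and what is done with
# unsupported ones is decided by the calling graph passes, so "graph_operations" and
# "place_on_cpu" below are placeholders, not part of this module):
#
#     supported = SupportedOperators()
#     for op in graph_operations:
#         if not supported.is_operator_supported(op):
#             place_on_cpu(op)  # hypothetical fallback to CPU execution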