# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
from collections import defaultdict

import numpy as np

from .data_type import BaseType
from .data_type import DataType
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
from .tensor import check_quantized_tens_scaling_equal
from .tensor import check_tens_quantized


# Custom decorator function to allow formatting docstrings containing "{}"
def docstring_format_args(args):
    def docstring(func):
        func.__doc__ = func.__doc__.format(*args)
        return func

    return docstring
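
# The decorator above is used throughout the SupportedOperators class below to bake
# the class-level limits into each constraint's docstring, which doubles as the
# user-facing message. For example, constraint_stride_range below is declared as:
#
#     @classmethod
#     @docstring_format_args(stride_range)
#     def constraint_stride_range(cls, op):
#         "Stride values for both width and height must be in the range [{}, {}]"
#
# so that the printed constraint text contains the actual stride_range limits (here 1 and 3).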


def warn_cpu(op, msg):
    print("Warning: {} {}, placing on CPU".format(op.type, msg))


class SupportedOperators:
    # Categorised lists of supported operators
    npu_pre_ops = set((Op.SplitSliceRead,))
    convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
    depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
    transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
    convolution_like_ops = convolution_ops | depthwise_convolution_ops | transpose_convolution_ops
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
    resizing_ops = set((Op.ResizeBilinear,))
    fc_vector_products = set((Op.QuantizedMatMul, Op.MatMul, Op.FullyConnected,))
    mac_main_ops = (
        # RNN/LSTM/GRU
        set((Op.BlockLSTM,))
        # conv/depthwiseconv/transposeconv
        | convolution_like_ops
        # pooling
        | pooling_ops
        # resizing/upscaling
        | resizing_ops
        # FC layers
        | fc_vector_products
    )
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
    binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
    binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    supported_int32_tensor_ops = (
        set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    )
    activation_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Sigmoid, Op.Tanh, Op.Softmax,))
    npu_post_ops = (
        # activation functions
        activation_ops
        # concatenation write direction
        | set((Op.ConcatSliceWrite,))
        # Quantization
        | set((Op.Quantize,))
    )
    split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
    concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
    memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,)) | concat_ops | split_ops
    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
    supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
    supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
    # Supported data types
    supported_op_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
    supported_bias_dtypes = set((DataType.int32, DataType.int64))
    # Defined ranges for allowed values:
    tens_dim_range = (1, 65535)
    stride_range = (1, 3)
    dilation_range = (1, 2)
    dilated_height_range = (1, 64)
    dilated_product_range = (1, 64 * 64)
    weights_limit = 127 * 65536

    def __init__(self):
        # Set up supported operator restriction checkers
        self.supported_operator_restrictions = {}
        self.supported_operator_restrictions.update(
            {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
        )
        # Set up the generic constraints. Note: the order matters
        self.generic_constraints = []
        self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
        self.generic_constraints.append(SupportedOperators.constraint_tens_output_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_input_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
        self.generic_constraints.append(SupportedOperators.constraint_tens_int32_ops)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
        self.generic_constraints.append(SupportedOperators.constraint_faf)
        # Set up specific constraints. The key in the dictionary must be a tuple of op types the constraints apply to
        self.specific_constraints = defaultdict(list)
        # Conv-like ops have the same checks applied to them:
        conv_like_ops = tuple(SupportedOperators.convolution_like_ops)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_type)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_stride_range)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_type)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilation_range)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_height_range)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_dilated_product_range)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_type)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_nonconst)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_weights_limit)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_type)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_bias_40bit)
        self.specific_constraints[conv_like_ops].append(SupportedOperators.constraint_batch_size)

    def get_constraints_list(self, op_type):
        constraint_list = list(self.generic_constraints)
        for ops in self.specific_constraints:
            if op_type in ops:
                constraint_list.extend(self.specific_constraints[ops])
        return constraint_list

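    # An operator is considered NPU-supported only if it passes every check below:
    # membership of the supported_operators set, then the generic constraints (in order),
    # then any op-specific constraints, and finally the older per-op-type restriction
    # checkers registered in __init__. Any failure places the op on the CPU.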
    def is_operator_supported(self, op):
        if op.type not in SupportedOperators.supported_operators:
            if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
                print("Info: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
            return False

        for constraint in self.get_constraints_list(op.type):
            valid, extra = constraint(op)
            if not valid:
                print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
                print(" - {}".format(constraint.__doc__))
                if extra:
                    print("   {}".format(extra))
                return False

        if op.type in self.supported_operator_restrictions:
            return self.supported_operator_restrictions[op.type](op)
        return True

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output Tensors must have a defined shape"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if not tens.has_fully_defined_shape():
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_shapeless(op):
        "Scalar or Broadcasting Tensors are only valid for Input Tensors"
        valid = True
        extra = []
        for tens in op.outputs:
            if tens.shape == []:
                valid = False
                extra.append("Output Tensor '{}' is shapeless".format(tens.name))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([shapeless_input_ops])
    def constraint_tens_input_shapeless(cls, op):
        "Scalar or Broadcasting Input Tensors are only valid for op type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                valid = False
                extra.append(tens.name)
        extra = "Op has shapeless input tensor(s): {}".format(", ".join(extra))
        return valid, extra

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output Tensors must not be greater than 4D"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if len(tens.shape) > 4:
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_op_dtypes])
    def constraint_tens_dtype(cls, op):
        "Input(s), Output and Weight Tensors must be of type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if tens.dtype not in cls.supported_op_dtypes:
                valid = False
                extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_int32_tensor_ops])
    def constraint_tens_int32_ops(cls, op):
        "Tensors which are int32 are only valid when op type is: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
                valid = False
                extra.append(tens.name)
        extra = "Op has int32 tensor(s): {}".format(", ".join(extra))
        return valid, extra

    @classmethod
    @docstring_format_args(tens_dim_range)
    def constraint_tens_dimension(cls, op):
        "Tensor dimensions must be in the range [{}, {}]"
        tens_min, tens_max = cls.tens_dim_range
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if not all(tens_min <= dim <= tens_max for dim in tens.shape):
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Tensors must have quantization parameters"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is None:
                valid = False
                extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
                valid = False
                extra.append("Tensor '{}' has quantization scale: {}".format(tens.name, tens.quantization.scale_f32))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_fused_activations])
    def constraint_faf(cls, op):
        "The fused activation function (if present) must be one of type: {}"
        faf = op.activation
        valid = (faf is None) or (faf in cls.supported_fused_activations)
        extra = "Op has its fused activation function as: {}".format(faf)
        return valid, extra

    @staticmethod
    def constraint_stride_type(op):
        "Stride values for both width and height must be integer types"
        w = op.attrs["stride_w"]
        h = op.attrs["stride_h"]
        valid = is_integer(w) and is_integer(h)
        extra = "Op has stride WxH as: {}x{}".format(repr(w), repr(h))
        return valid, extra

    @classmethod
    @docstring_format_args(stride_range)
    def constraint_stride_range(cls, op):
        "Stride values for both width and height must be in the range [{}, {}]"
        w = op.attrs["stride_w"]
        h = op.attrs["stride_h"]
        stride_min, stride_max = cls.stride_range
        valid = (stride_min <= w <= stride_max) and (stride_min <= h <= stride_max)
        extra = "Op has stride WxH as: {}x{}".format(w, h)
        return valid, extra

    @staticmethod
    def constraint_dilation_type(op):
        "Dilation factor values for both width and height must be integer types"
        w = op.attrs.get("dilation_w_factor", 1)
        h = op.attrs.get("dilation_h_factor", 1)
        valid = is_integer(w) and is_integer(h)
        extra = "Op has dilation factor WxH as: {}x{}".format(repr(w), repr(h))
        return valid, extra

    @classmethod
    @docstring_format_args(dilation_range)
    def constraint_dilation_range(cls, op):
        "Dilation factor values for both width and height must be in the range [{}, {}]"
        w = op.attrs.get("dilation_w_factor", 1)
        h = op.attrs.get("dilation_h_factor", 1)
        dilation_min, dilation_max = cls.dilation_range
        valid = (dilation_min <= w <= dilation_max) and (dilation_min <= h <= dilation_max)
        extra = "Op has dilation factor WxH as: {}x{}".format(w, h)
        return valid, extra

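    # The two checks below use the effective (dilated) kernel size,
    # dilated_dim = (kernel_dim - 1) * dilation_factor + 1, taken from the weight tensor
    # shape, and bound the dilated height and the dilated width x height product.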
    @classmethod
    @docstring_format_args(dilated_height_range)
    def constraint_dilated_height_range(cls, op):
        "Dilated kernel height must be in the range [{}, {}]"
        h = (op.weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
        dilated_height_min, dilated_height_max = cls.dilated_height_range
        valid = dilated_height_min <= h <= dilated_height_max
        extra = "Op has dilated kernel height as: {}".format(h)
        return valid, extra

    @classmethod
    @docstring_format_args(dilated_product_range)
    def constraint_dilated_product_range(cls, op):
        "Product of dilated kernel width and height must be in the range [{}, {}]"
        weights = op.weights
        w = (weights.shape[1] - 1) * op.attrs.get("dilation_w_factor", 1) + 1
        h = (weights.shape[0] - 1) * op.attrs.get("dilation_h_factor", 1) + 1
        product = w * h
        dilated_product_min, dilated_product_max = cls.dilated_product_range
        valid = dilated_product_min <= product <= dilated_product_max
        extra = "Op has product of dilated kernel width and height as: {}".format(product)
        return valid, extra

    @staticmethod
    def constraint_weights_type(op):
        "Weight Tensor must be 8-bit"
        weights = op.weights
        valid = weights.element_size() == 1
        extra = "Tensor '{}' is {}-bit".format(weights.name, int(weights.element_size() * 8))
        return valid, extra

    @staticmethod
    def constraint_weights_nonconst(op):
        "Weight tensor cannot be non-constant"
        weights = op.weights
        valid = weights.values is not None
        extra = "Tensor '{}' has non-constant values".format(weights.name)
        return valid, extra

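    # The weight-sum check below subtracts the zero point from the quantised weights,
    # sums their absolute values over the first three axes, and compares the largest
    # remaining sum against weights_limit (127 * 65536); presumably a hardware /
    # weight-encoder restriction rather than a TFLite one.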
    @classmethod
    @docstring_format_args([weights_limit])
    def constraint_weights_limit(cls, op):
        "The sum of the weights cannot exceed {}"
        weights = op.weights
        values = weights.quant_values.astype(np.int64) - weights.quantization.zero_point
        limit = np.amax(np.sum(np.absolute(values), axis=(0, 1, 2)))
        valid = limit <= cls.weights_limit
        extra = "Tensor '{}' has the sum of weights: {}".format(weights.name, limit)
        return valid, extra

    @classmethod
    @docstring_format_args([supported_bias_dtypes])
    def constraint_bias_type(cls, op):
        "Optional Bias Tensor must be of type: {}"
        valid = True
        extra = ""
        bias = op.bias
        if bias:
            valid = bias.dtype in cls.supported_bias_dtypes
            extra = "Tensor '{}' has data type: {}".format(bias.name, bias.dtype)
        return valid, extra

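    # The 40-bit check below counts the binary digits of each quantised bias value via
    # len(bin(value)[2:]); only int64 biases are checked, since int32 values always fit.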
    @staticmethod
    def constraint_bias_40bit(op):
        "Optional Bias Tensor values must fit within 40-bits"
        valid = True
        extra = ""
        bias = op.bias
        if bias and bias.dtype == DataType.int64:
            valid = all(len(bin(quant_value)[2:]) <= 40 for quant_value in bias.quant_values)
            extra = "Tensor '{}' has values larger than 40-bits".format(bias.name)
        return valid, extra

    @staticmethod
    def constraint_batch_size(op):
        "IFM Tensor batch size must be 1"
        ifm = op.ifm
        valid = ifm.shape[0] == 1
        extra = "Tensor '{}' has batch size: {}".format(ifm.name, ifm.shape[0])
        return valid, extra

    @classmethod
    def check_depthwise_convolution_restrictions(cls, op):
        # check depth
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if op.attrs["depth_multiplier"] > 1 and not (
            (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
        ):
            print(
                "Warning: for depth multipliers > 1,",
                "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
                "Placing on CPU",
            )
            return False
        return True

    @classmethod
    def check_transpose_convolution_restrictions(cls, op):
        # check stride
        stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
        if stride_h != 2 or stride_w != 2:
            print("Warning: stride must be equal to 2, placing on CPU")
            return False

        # check output dimensions
        ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
        ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
        ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
        if op.attrs["padding"] == b"SAME":
            if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
                print(
                    "Warning: for",
                    op.type,
                    "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
                    "Placing on CPU",
                )
                return False
        elif op.attrs["padding"] == b"VALID":
            kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
            if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
                ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
            ):
                print(
                    "Warning: for",
                    op.type,
                    "using VALID padding, output dimensions must equal input dimensions multiplied by stride,",
                    "minus difference between kernel size and stride. Placing on CPU",
                )
                return False
        return True

    @classmethod
    def check_pooling_restrictions(cls, op):
        # check stride
        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
        if not is_integer(stride_w) or not is_integer(stride_h):
            print("Warning:", op.type, "has non-integer stride, placing on CPU")
            return False
        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
            print(
                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
                    op.type, stride_w, stride_h
                )
            )
            return False

        # check data type
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if ifm_tensor.dtype != ofm_tensor.dtype:
            if op.type != Op.ReduceSum:
                print("Warning: input data type doesn't match output data type, placing on CPU")
                return False
            # TODO: else check ReduceSum restrictions.

        # check batch size
        if ifm_tensor.shape[0] != 1:
            print("Warning: input batch size must be 1, placing on CPU")
            return False

        # check kernel size
        kernel_w, kernel_h = op.attrs["filter_width"], op.attrs["filter_height"]
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"SAME":
            if not 1 <= kernel_w <= 8 or not 1 <= kernel_h <= 8:
                print(
                    "Warning:",
                    op.type,
                    "has kernel size ({}, {}), only kernel sizes in range [1, 8] are allowed. Placing on CPU".format(
                        kernel_w, kernel_h
                    ),
                )
                return False
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID" or op.type in cls.max_pooling_ops:
            if not 1 <= kernel_w * kernel_h <= 256 * 256:
                print(
                    "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format(
                        256 * 256
                    ),
                    "placing on CPU",
                )
                return False
            if not 1 <= kernel_h <= 256:
                print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU")
                return False

        return True

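    # ResizeBilinear is accepted when it is effectively a no-op (1x1 input or identical
    # input/output shapes) or when the output is the input upscaled by a power of two,
    # with the upscaled size reduced by 1 at each doubling when align_corners is set.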
    @classmethod
    def check_resize_restrictions(cls, op):
        # check unsupported upscaling factor
        if op.type == Op.ResizeBilinear:
            if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
                return True
            if op.inputs[0].shape == op.outputs[0].shape:
                return True
            upscaled_shape = np.array(op.inputs[0].shape[1:3])
            out_shape = np.array(op.outputs[0].shape[1:3])
            while (upscaled_shape < out_shape).all():
                upscaled_shape *= 2
                if op.attrs["align_corners"]:
                    upscaled_shape -= 1
                if np.array_equal(out_shape, upscaled_shape):
                    return True
        return False

    @classmethod
    def check_vector_product_restrictions(cls, op):
        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type))
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-const weights, placing on CPU")
            return False

        return True

    @classmethod
    def check_element_wise_restrictions(cls, op):
        # check data type
        ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        # input and output datatype must match for these operators
        if (
            op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
            and ifm_tensor.dtype != ofm_tensor.dtype
        ):
            print("Warning:", op.type, "must have same input and output datatype, placing on CPU")
            return False
        if op.type in cls.binary_elem_wise_add_mul_sub:
            # both inputs must have same type
            if ifm_tensor.dtype != ifm2_tensor.dtype:
                print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU")
                return False
            # signed input check
            if ifm_tensor.dtype.type & BaseType.Signed:
                # output must be signed
                if ofm_tensor.dtype.type & BaseType.Unsigned:
                    print("Warning: only signed output types supported for {}, placing on CPU".format(op.type))
                    return False
                # and 8, 16 or 32-bit
                bit_lengths = {8, 16, 32}
                if ofm_tensor.element_size() * 8 not in bit_lengths:
                    print(
                        "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
                    )
                    return False
            # unsigned input check, output must be same type or int32
            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
                ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
            ):
                print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU")
                return False
        elif op.type in cls.binary_elem_wise_shift_ops:
            if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
                return False
            if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "output datatype is not int32, placing on CPU")
                return False

        # check batch size
        if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
            print(
                "Warning:",
                op.type,
                "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
            )
            return False
        if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
            if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
                print(
                    "Warning:",
                    op.type,
                    "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
                )
                return False

        # negative alpha values are not supported
        if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
            print("Warning:", op.type, "has negative alpha, placing on CPU")
            return False

        # check if ifm or ifm2 has ofm shape
        if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
            print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
            return False

        if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
            return False

        return True

    @classmethod
    def check_memory_only_restrictions(cls, op):
        if op.type == Op.StridedSlice:
            if len(op.inputs) != 4:
                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                return False
            input_tens, begin_tens, end_tens, strides_tens = op.inputs
            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
                return False
            if not (
                len(input_tens.shape)
                == len(op.outputs[0].shape)
                == len(begin_tens.values)
                == len(end_tens.values)
                == len(strides_tens.values)
            ):
                warn_cpu(op, "has input tensors with shapes that are not supported")
                return False
            # check stride size
            if any(stride != 1 for stride in strides_tens.values):
                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                return False
            # check ellipsis_mask
            if op.attrs["ellipsis_mask"] != 0:
                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                return False
            # check if both new_axis_mask and shrink_axis_mask have bit set
            if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
                return False
            # Calculate offset start/end
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
            # check "end - begin" doesn't result in any zero or negative elements
            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
                warn_cpu(
                    op,
                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
                        begin_tens.values, end_tens.values
                    ),
                )
                return False
        if op.type == Op.SplitV:
            # check that maximum one size is set to -1, indicating that size should be inferred
            sizes = op.inputs[1].values
            num_to_be_inferred = 0
            for size in sizes:
                if size == -1:
                    num_to_be_inferred += 1

            if num_to_be_inferred > 1:
                print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
                return False
        if op.type in set((Op.Concat, Op.ConcatTFLite,)):
            axis = op.attrs.get("axis", None)
            if axis is None:
                print("Warning:", op.type, "invalid or missing axis, placing on CPU")
                return False
            if axis < 0:
                axis += len(op.inputs[0].shape)
            if not 0 <= axis < len(op.inputs[0].shape):
                print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
                return False
            ofm = op.outputs[0]
            ofm_dims = len(ofm.shape)
            for ifm in op.inputs:
                if len(ifm.shape) != ofm_dims:
                    return False
                for i in range(ofm_dims):
                    if i != axis and ifm.shape[i] != ofm.shape[i]:
                        print(
                            "Warning:",
                            op.type,
                            "invalid ifm:",
                            ifm.name,
                            ifm.shape,
                            "mismatch in dimension",
                            i,
                            ", placing on CPU",
                        )
                        return False

        return True

    @classmethod
    def check_quantization_restrictions_binary_elem_wise(cls, op):
        # checks that IFM1, IFM2 and OFM quantization are equal for binary ops

        assert len(op.inputs) >= 2 and len(op.outputs) == 1

        if (
            not check_tens_quantized(op.inputs[0])
            or not check_tens_quantized(op.inputs[1])
            or not check_tens_quantized(op.outputs[0])
        ):
            warn_cpu(op, "has non-quantised input and/or output tensors")
            return False

        if not check_quantized_tens_scaling_equal(op.inputs[0], op.inputs[1]) or not check_quantized_tens_scaling_equal(
            op.inputs[0], op.outputs[0]
        ):
            warn_cpu(op, "has input/output tensors with different quantisation which is illegal")
            return False

        return True

    @classmethod
    def check_activation_ops(cls, op):
        if op.type == Op.Softmax:
            ifm_tensor = op.inputs[0]
            ofm_tensor = op.outputs[0]

            # check data type
            if ifm_tensor.dtype != ofm_tensor.dtype:
                print("Warning:", op.type, "input type differs from output type, placing on CPU")
                return False

            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
                print(
                    "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
                )
                return False

            # check shape
            if ifm_tensor.shape != ofm_tensor.shape:
                print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
                return False

        return True

    @classmethod
    def check_bias_restrictions(cls, bias_tensor):
        # check data type
        if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
            print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
            return False

        # check if values fit in 40-bit
        if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
            for quant_value in bias_tensor.quant_values:
                if not (-(1 << 39) <= quant_value < (1 << 39)):
                    print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
                    return False

        return True
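

# Illustrative usage (a minimal sketch; the actual call sites live elsewhere in the
# compiler, where each graph operation is queried to decide NPU vs CPU placement):
#
#     supported_operators = SupportedOperators()
#     if not supported_operators.is_operator_supported(op):
#         ...  # leave the operation on the CPU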