# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The SupportedOperators class is a collection of all supported operators and parameter checks.
import numpy as np

from .data_type import BaseType
from .data_type import DataType
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op


# Custom decorator function to allow formatting docstrings containing "{}"
def docstring_format_args(args):
    def docstring(func):
        func.__doc__ = func.__doc__.format(*args)
        return func

    return docstring


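# Note on docstring_format_args: it is used by the constraint methods further down to
# substitute the class-level sets into a constraint's docstring, so the message printed
# when the constraint fails lists the actual supported values, e.g.
#
#     @classmethod
#     @docstring_format_args([supported_dtypes])
#     def constraint_tens_dtype(cls, op):
#         "Tensors must be of type: {}"

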
def warn_cpu(op, msg):
    print("Warning: {} {}, placing on CPU".format(op.type, msg))


class SupportedOperators:
    # Categorised lists of supported operators
    npu_pre_ops = set((Op.SplitSliceRead,))
    convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
    depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
    transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
    resizing_ops = set((Op.ResizeBilinear,))
    fc_vector_products = set((Op.QuantizedMatMul, Op.MatMul, Op.FullyConnected,))
    mac_main_ops = (
        # RNN/LSTM/GRU
        set((Op.BlockLSTM,))
        # convolutions
        | convolution_ops
        # depth-wise convolutions
        | depthwise_convolution_ops
        # transpose convolutions
        | transpose_convolution_ops
        # pooling
        | pooling_ops
        # resizing/upscaling
        | resizing_ops
        # FC layers
        | fc_vector_products
    )
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
    binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
    binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    supported_int32_tensor_ops = (
        set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    )
    activation_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Sigmoid, Op.Tanh, Op.Softmax,))
    npu_post_ops = (
        # activation functions
        activation_ops
        # concatenation write direction
        | set((Op.ConcatSliceWrite,))
        # Quantization
        | set((Op.Quantize,))
    )
    split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
    concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
    memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,)) | concat_ops | split_ops
    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
    supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
    supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
    supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
    # Defined ranges for allowed values:
    tens_dim_range = (1, 65535)

    def __init__(self):
        # Setup supported operator restriction checkers
        self.supported_operator_restrictions = {}
        self.supported_operator_restrictions.update(
            {op: self.check_convolution_restrictions for op in SupportedOperators.convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
        )
        # Setup the generic constraints. Note: the order matters
        self.generic_constraints = []
        self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
        self.generic_constraints.append(SupportedOperators.constraint_tens_output_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_input_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
        self.generic_constraints.append(SupportedOperators.constraint_tens_int32_ops)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
        self.generic_constraints.append(SupportedOperators.constraint_faf)

    def is_operator_supported(self, op):
        if op.type not in SupportedOperators.supported_operators:
            print('Info: "{}" is not supported on the NPU. Placing on CPU instead'.format(op.type))
            return False
        for constraint in self.generic_constraints:
            valid, extra = constraint(op)
            if not valid:
                print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
                print(" - {}".format(constraint.__doc__))
                if extra:
                    print("   {}".format(extra))
                return False
        if op.type in self.supported_operator_restrictions:
            return self.supported_operator_restrictions[op.type](op)
        return True

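    # Generic constraints: each one returns a (valid, extra) tuple, where 'extra' is a
    # detail string printed alongside the constraint's docstring when the check fails.
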
    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output Tensors must have a defined shape"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if not tens.has_fully_defined_shape():
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_shapeless(op):
        "Scalar or Broadcasting Tensors are only valid for Input Tensors"
        valid = True
        extra = []
        for tens in op.outputs:
            if tens.shape == []:
                valid = False
                extra.append("Output Tensor '{}' is shapeless".format(tens.name))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([shapeless_input_ops])
    def constraint_tens_input_shapeless(cls, op):
        "Scalar or Broadcasting Input Tensors are only valid for op type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                valid = False
                extra.append(tens.name)
        extra = "Op '{}' has shapeless input tensor(s): {}".format(op.name, ", ".join(extra))
        return valid, extra

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output Tensors must not be greater than 4D"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if len(tens.shape) > 4:
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_dtypes])
    def constraint_tens_dtype(cls, op):
        "Tensors must be of type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if tens.dtype not in cls.supported_dtypes:
                valid = False
                extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_int32_tensor_ops])
    def constraint_tens_int32_ops(cls, op):
        "Tensors which are int32 are only valid when op type is: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
                valid = False
                extra.append(tens.name)
        extra = "Op '{}' has int32 tensor(s): {}".format(op.name, ", ".join(extra))
        return valid, extra

    @classmethod
    @docstring_format_args(tens_dim_range)
    def constraint_tens_dimension(cls, op):
        "Tensor dimensions must be in the range {}-{} (inclusive)"
        tens_min, tens_max = cls.tens_dim_range
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if not all(tens_min <= dim <= tens_max for dim in tens.shape):
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Tensors must have quantization parameters"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is None:
                valid = False
                extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
                valid = False
                extra.append("Tensor '{}' has quantization scale: {}".format(tens.name, tens.quantization.scale_f32))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_fused_activations])
    def constraint_faf(cls, op):
        "The fused activation function (if present) must be one of type: {}"
        faf = op.activation
        valid = (faf is None) or (faf in cls.supported_fused_activations)
        extra = "Op '{}' has its fused activation function as: {}".format(op.name, faf)
        return valid, extra

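    # Operator-specific restriction checks: each returns True if the op can run on the
    # NPU, or prints a warning and returns False so the op is placed on the CPU.
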
    @classmethod
    def check_convolution_restrictions(cls, op):
        # check stride
        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
        if not is_integer(stride_w) or not is_integer(stride_h):
            print("Warning:", op.type, "has non-integer stride, placing on CPU")
            return False
        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
            print(
                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
                    op.type, stride_w, stride_h
                )
            )
            return False

        # check dilation
        dilation_w_factor = op.attrs.get("dilation_w_factor", 1)
        dilation_h_factor = op.attrs.get("dilation_h_factor", 1)
        if not is_integer(dilation_w_factor) or not is_integer(dilation_h_factor):
            print("Warning:", op.type, "has non-integer dilation factor, placing on CPU")
            return False
        if not 1 <= dilation_w_factor <= 2 or not 1 <= dilation_h_factor <= 2:
            print(
                "Warning:",
                op.type,
                "has dilation factors ({}, {}), only factors in range [1, 2] are allowed. Placing on CPU".format(
                    dilation_w_factor, dilation_h_factor
                ),
            )
            return False

        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            print("Warning: only 8-bit weights are supported, placing on CPU")
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check kernel size [HWIO]
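        # effective kernel extent along each axis with dilation: (k - 1) * dilation + 1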
        dilated_weight_w = (weight_tensor.shape[1] - 1) * dilation_w_factor + 1
        dilated_weight_h = (weight_tensor.shape[0] - 1) * dilation_h_factor + 1

        # kernel limits
        if not 1 <= dilated_weight_h <= 64:
            print("Warning:", op.type, "has kernel height outside of range [1, 64], placing on CPU")
            return False
        if not 1 <= dilated_weight_w * dilated_weight_h <= 64 * 64:
            print(
                "Warning: product of kernel width and height must be >= 1 and not exceed 64 * 64 ({}),".format(64 * 64),
                "placing on CPU",
            )
            return False

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-constant weights, placing on CPU")
            return False

        # check weight sums over [HWI]
        zero_point = weight_tensor.quantization.zero_point
        quant_weights = weight_tensor.quant_values.astype(np.int64)
        weights = quant_weights - zero_point
        totals = np.sum(np.absolute(weights), axis=(0, 1, 2))

        if np.amax(totals) > 127 * 65536:
            print("Warning: sum of weights exceeds 127 * 65536 ({}), placing on CPU".format(127 * 65536))
            return False

        # check batch size
        if ifm_tensor.shape[0] != 1:
            print("Warning: only batch sizes of 1 are supported, placing on CPU")
            return False

        return True

    @classmethod
    def check_depthwise_convolution_restrictions(cls, op):
        # check depth
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if op.attrs["depth_multiplier"] > 1 and not (
            (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
        ):
            print(
                "Warning: for depth multipliers > 1,",
                "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
                "Placing on CPU",
            )
            return False
        return cls.check_convolution_restrictions(op)

    @classmethod
    def check_transpose_convolution_restrictions(cls, op):
        # check stride
        stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
        if stride_h != 2 or stride_w != 2:
            print("Warning: stride must be equal to 2, placing on CPU")
            return False

        # check output dimensions
        ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
        ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
        ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
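        # expected output size per spatial axis (as enforced below):
        #   SAME padding:  ofm = ifm * stride
        #   VALID padding: ofm = ifm * stride + max(kernel - stride, 0)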
        if op.attrs["padding"] == b"SAME":
            if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
                print(
                    "Warning: for",
                    op.type,
                    "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
                    "Placing on CPU",
                )
                return False
        elif op.attrs["padding"] == b"VALID":
            kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
            if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
                ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
            ):
                print(
                    "Warning: for",
                    op.type,
                    "using VALID padding, output dimensions must equal input dimensions multiplied by stride,",
                    "plus the difference between kernel size and stride. Placing on CPU",
                )
                return False

        return cls.check_convolution_restrictions(op)

    @classmethod
    def check_pooling_restrictions(cls, op):
        # check stride
        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
        if not is_integer(stride_w) or not is_integer(stride_h):
            print("Warning:", op.type, "has non-integer stride, placing on CPU")
            return False
        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
            print(
                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
                    op.type, stride_w, stride_h
                )
            )
            return False

        # check data type
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if ifm_tensor.dtype != ofm_tensor.dtype:
            if op.type != Op.ReduceSum:
                print("Warning: input data type doesn't match output data type, placing on CPU")
                return False
            # TODO: else check ReduceSum restrictions.

        # check batch size
        if ifm_tensor.shape[0] != 1:
            print("Warning: input batch size must be 1, placing on CPU")
            return False

        # check kernel size
        kernel_w, kernel_h = op.attrs["filter_width"], op.attrs["filter_height"]
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"SAME":
            if not 1 <= kernel_w <= 8 or not 1 <= kernel_h <= 8:
                print(
                    "Warning:",
                    op.type,
                    "has kernel size ({}, {}), only kernel sizes in range [1, 8] are allowed. Placing on CPU".format(
                        kernel_w, kernel_h
                    ),
                )
                return False
        if (op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID") or op.type in cls.max_pooling_ops:
            if not 1 <= kernel_w * kernel_h <= 256 * 256:
                print(
                    "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format(
                        256 * 256
                    ),
                    "placing on CPU",
                )
                return False
            if not 1 <= kernel_h <= 256:
                print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU")
                return False

        return True

    @classmethod
    def check_resize_restrictions(cls, op):
        # check unsupported upscaling factor
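        # ResizeBilinear is accepted when the IFM is 1x1, when IFM and OFM shapes are
        # identical, or when the OFM shape can be reached by repeated 2x upscaling of
        # the IFM (each step reduced by 1 when align_corners is set), as checked below.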
        if op.type == Op.ResizeBilinear:
            if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
                return True
            if op.inputs[0].shape == op.outputs[0].shape:
                return True
            upscaled_shape = np.array(op.inputs[0].shape[1:3])
            out_shape = np.array(op.outputs[0].shape[1:3])
            while (upscaled_shape < out_shape).all():
                upscaled_shape *= 2
                if op.attrs["align_corners"]:
                    upscaled_shape -= 1
                if np.array_equal(out_shape, upscaled_shape):
                    return True
            return False

    @classmethod
    def check_vector_product_restrictions(cls, op):
        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type))
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-const weights, placing on CPU")
            return False

        return True

    @classmethod
    def check_element_wise_restrictions(cls, op):
        # check data type
        ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        # input and output datatype must match for these operators
        if (
            op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
            and ifm_tensor.dtype != ofm_tensor.dtype
        ):
            print("Warning:", op.type, "must have same input and output datatype, placing on CPU")
            return False
        if op.type in cls.binary_elem_wise_add_mul_sub:
            # both inputs must have same type
            if ifm_tensor.dtype != ifm2_tensor.dtype:
                print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU")
                return False
            # signed input check
            if ifm_tensor.dtype.type & BaseType.Signed:
                # output must be signed
                if ofm_tensor.dtype.type & BaseType.Unsigned:
                    print("Warning: only signed output types supported for {}, placing on CPU".format(op.type))
                    return False
                # and 8, 16 or 32-bit
                bit_lengths = {8, 16, 32}
                if ofm_tensor.element_size() * 8 not in bit_lengths:
                    print(
                        "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
                    )
                    return False
            # unsigned input check, output must be same type or int32
            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
                ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
            ):
                print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU")
                return False
        elif op.type in cls.binary_elem_wise_shift_ops:
            if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
                return False
            if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "output datatype is not int32, placing on CPU")
                return False

        # check batch size
        if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
            print(
                "Warning:",
                op.type,
                "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
            )
            return False
        if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
            if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
                print(
                    "Warning:",
                    op.type,
                    "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
                )
                return False

        # negative alpha values are not supported
        if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
            print("Warning:", op.type, "has negative alpha, placing on CPU")
            return False

        # check if ifm or ifm2 has ofm shape
        if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
            print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
            return False

        if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
            return False

        return True

    @classmethod
    def check_memory_only_restrictions(cls, op):
        if op.type == Op.StridedSlice:
            if len(op.inputs) != 4:
                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                return False
            input_tens, begin_tens, end_tens, strides_tens = op.inputs
            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
                return False
            if not (
                len(input_tens.shape)
                == len(op.outputs[0].shape)
                == len(begin_tens.values)
                == len(end_tens.values)
                == len(strides_tens.values)
            ):
                warn_cpu(op, "has input tensors with shapes that are not supported")
                return False
            # check stride size
            if any(stride != 1 for stride in strides_tens.values):
                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                return False
            # check ellipsis_mask
            if op.attrs["ellipsis_mask"] != 0:
                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                return False
            # check if both new_axis_mask and shrink_axis_mask have bit set
            if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
                return False
            # Calculate offset start/end
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
            # check "end - begin" doesn't result in any zero or negative elements
            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
                warn_cpu(
                    op,
                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
                        begin_tens.values, end_tens.values
                    ),
                )
                return False
        if op.type == Op.SplitV:
            # check that maximum one size is set to -1, indicating that size should be inferred
            sizes = op.inputs[1].values
            num_to_be_inferred = 0
            for size in sizes:
                if size == -1:
                    num_to_be_inferred += 1

            if num_to_be_inferred > 1:
                print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
                return False
        if op.type in set((Op.Concat, Op.ConcatTFLite,)):
            axis = op.attrs.get("axis", None)
            if axis is None:
                print("Warning:", op.type, "invalid or missing axis, placing on CPU")
                return False
            if axis < 0:
                axis += len(op.inputs[0].shape)
            if not 0 <= axis < len(op.inputs[0].shape):
                print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
                return False
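            # All inputs must have the same rank as the output and must match the output
            # shape on every axis except the concatenation axis (checked below).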
            ofm = op.outputs[0]
            ofm_dims = len(ofm.shape)
            for ifm in op.inputs:
                if len(ifm.shape) != ofm_dims:
                    return False
                for i in range(ofm_dims):
                    if i != axis and ifm.shape[i] != ofm.shape[i]:
                        print(
                            "Warning:",
                            op.type,
                            "invalid ifm:",
                            ifm.name,
                            ifm.shape,
                            "mismatch in dimension",
                            i,
                            ", placing on CPU",
                        )
                        return False

        return True

    @classmethod
    def check_quantization_restrictions_binary_elem_wise(cls, op):
        # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
        assert len(op.inputs) >= 2 and len(op.outputs) == 1

        if (
            op.inputs[0].quantization is None
            or not op.inputs[0].is_scaling_equal(op.inputs[1])
            or not op.inputs[0].is_scaling_equal(op.outputs[0])
        ):
            print(
                "Warning: Input/output tensors with different quantization are unsupported for the", op.type, "operator"
            )
            return False

        return True

    @classmethod
    def check_activation_ops(cls, op):
        if op.type == Op.Softmax:
            ifm_tensor = op.inputs[0]
            ofm_tensor = op.outputs[0]

            # check data type
            if ifm_tensor.dtype != ofm_tensor.dtype:
                print("Warning:", op.type, "input type differs from output type, placing on CPU")
                return False

            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
                print(
                    "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
                )
                return False

            # check shape
            if ifm_tensor.shape != ofm_tensor.shape:
                print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
                return False

        return True

    @classmethod
    def check_bias_restrictions(cls, bias_tensor):
        # check data type
        if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
            print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
            return False

        # check if values fit in 40 bits
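        # signed 40-bit range: [-(1 << 39), (1 << 39) - 1]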
        if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
            for quant_value in bias_tensor.quant_values:
                if not (-(1 << 39) <= quant_value < (1 << 39)):
                    print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
                    return False

        return True
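

# Typical usage (illustrative sketch; the surrounding compiler pass and the name
# 'graph_ops' are assumed, not defined in this module):
#
#     supported_operators = SupportedOperators()
#     for op in graph_ops:
#         if not supported_operators.is_operator_supported(op):
#             ...  # leave the operation to run on the CPU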