# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
import numpy as np

from .data_type import BaseType
from .data_type import DataType
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op


# Custom decorator function to allow formatting docstrings containing "{}"
def docstring_format_args(args):
    def docstring(func):
        func.__doc__ = func.__doc__.format(*args)
        return func

    return docstring
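

# Illustrative note (not from the original source): the decorator above is used on the
# constraint methods further down, e.g.
#     @classmethod
#     @docstring_format_args([supported_dtypes])
#     def constraint_tens_dtype(cls, op):
#         "Tensors must be of type: {}"
# so the "{}" in the reported reason is filled in with the actual class attribute values.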


def warn_cpu(op, msg):
    print("Warning: {} {}, placing on CPU".format(op.type, msg))
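    # e.g. warn_cpu(op, "has 3 input tensors, only 4 inputs are supported") prints:
    #   Warning: <op.type> has 3 input tensors, only 4 inputs are supported, placing on CPU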


class SupportedOperators:
    # Categorised lists of supported operators
    npu_pre_ops = set((Op.SplitSliceRead,))
    convolution_ops = set((Op.Conv2DBias, Op.Conv2D, Op.QuantizedConv2D,))
    depthwise_convolution_ops = set((Op.DepthwiseConv2DBias,))
    transpose_convolution_ops = set((Op.Conv2DBackpropInput,))
    max_pooling_ops = Op.op_set(Op.is_maxpool_op)
    avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
    pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
    resizing_ops = set((Op.ResizeBilinear,))
    fc_vector_products = set((Op.QuantizedMatMul, Op.MatMul, Op.FullyConnected,))
    mac_main_ops = (
        # RNN/LSTM/GRU
        set((Op.BlockLSTM,))
        # convolutions
        | convolution_ops
        # depth-wise convolutions
        | depthwise_convolution_ops
        # transpose convolutions
        | transpose_convolution_ops
        # pooling
        | pooling_ops
        # resizing/upscaling
        | resizing_ops
        # FC layers
        | fc_vector_products
    )
    unary_elem_wise_main_ops = Op.op_set(Op.is_unary_elementwise_op)
    binary_elem_wise_min_max_ops = set((Op.Minimum, Op.Maximum,))
    binary_elem_wise_shift_ops = set((Op.SHL, Op.SHR,))
    binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.Sub,))
    binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops
    supported_int32_tensor_ops = (
        set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
    )
    activation_ops = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Sigmoid, Op.Tanh, Op.Softmax,))
    npu_post_ops = (
        # activation functions
        activation_ops
        # concatenation write direction
        | set((Op.ConcatSliceWrite,))
        # Quantization
        | set((Op.Quantize,))
    )
    split_ops = set((Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack,))
    concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
    memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape, Op.ExpandDims,)) | concat_ops | split_ops
    shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
    supported_fused_activations = set((Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Tanh, Op.Sigmoid, Op.LUT,))
    supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
    supported_dtypes = set((DataType.uint8, DataType.int8, DataType.int16, DataType.int32))
    # Defined ranges for allowed values:
    tens_dim_range = (1, 65535)

    def __init__(self):
        # Setup supported operator restriction checkers
        self.supported_operator_restrictions = {}
        self.supported_operator_restrictions.update(
            {op: self.check_convolution_restrictions for op in SupportedOperators.convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_depthwise_convolution_restrictions for op in SupportedOperators.depthwise_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_transpose_convolution_restrictions for op in SupportedOperators.transpose_convolution_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_pooling_restrictions for op in SupportedOperators.pooling_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_resize_restrictions for op in SupportedOperators.resizing_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_vector_product_restrictions for op in SupportedOperators.fc_vector_products}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_element_wise_restrictions for op in SupportedOperators.elem_wise_main_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_memory_only_restrictions for op in SupportedOperators.memory_only_ops}
        )
        self.supported_operator_restrictions.update(
            {op: self.check_activation_ops for op in SupportedOperators.activation_ops}
        )
        # Setup the generic constraints. Note: the order matters
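        # (each generic constraint is a callable taking an op and returning a (valid, extra) tuple;
        #  'extra' is the detail string printed when the constraint fails)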
        self.generic_constraints = []
        self.generic_constraints.append(SupportedOperators.constraint_tens_defined_shape)
        self.generic_constraints.append(SupportedOperators.constraint_tens_output_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_input_shapeless)
        self.generic_constraints.append(SupportedOperators.constraint_tens_shape_size)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dtype)
        self.generic_constraints.append(SupportedOperators.constraint_tens_int32_ops)
        self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
        self.generic_constraints.append(SupportedOperators.constraint_faf)

    def is_operator_supported(self, op):
        if op.type not in SupportedOperators.supported_operators:
            if op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const):
                print("Info: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
            return False
        for constraint in self.generic_constraints:
            valid, extra = constraint(op)
            if not valid:
                print("Warning: {} '{}' is not supported on the NPU. Placing on CPU instead".format(op.type, op.name))
                print(" - {}".format(constraint.__doc__))
                if extra:
                    print("   {}".format(extra))
                return False
        if op.type in self.supported_operator_restrictions:
            return self.supported_operator_restrictions[op.type](op)
        return True
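
    # Illustrative caller-side sketch (assumed usage, not part of this class): a graph pass
    # would typically keep an op on the NPU only if this check passes, e.g.
    #     supported_operators = SupportedOperators()
    #     op.run_on_npu = supported_operators.is_operator_supported(op)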

    @staticmethod
    def constraint_tens_defined_shape(op):
        "Input(s) and Output Tensors must have a defined shape"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if not tens.has_fully_defined_shape():
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_output_shapeless(op):
        "Scalar or Broadcasting Tensors are only valid for Input Tensors"
        valid = True
        extra = []
        for tens in op.outputs:
            if tens.shape == []:
                valid = False
                extra.append("Output Tensor '{}' is shapeless".format(tens.name))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([shapeless_input_ops])
    def constraint_tens_input_shapeless(cls, op):
        "Scalar or Broadcasting Input Tensors are only valid for op type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs if tens]
        for tens in tensors:
            if (tens.shape == []) and (op.type not in cls.shapeless_input_ops):
                valid = False
                extra.append(tens.name)
        extra = "Op '{}' has shapeless input tensor(s): {}".format(op.name, ", ".join(extra))
        return valid, extra

    @staticmethod
    def constraint_tens_shape_size(op):
        "Input(s) and Output Tensors must not be greater than 4D"
        valid = True
        extra = []
        tensors = [tens for tens in op.inputs + op.outputs if tens]
        for tens in tensors:
            if len(tens.shape) > 4:
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_dtypes])
    def constraint_tens_dtype(cls, op):
        "Tensors must be of type: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if tens.dtype not in cls.supported_dtypes:
                valid = False
                extra.append("Tensor '{}' has data type: {}".format(tens.name, tens.dtype))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_int32_tensor_ops])
    def constraint_tens_int32_ops(cls, op):
        "Tensors which are int32 are only valid when op type is: {}"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if (tens.dtype == DataType.int32) and (op.type not in cls.supported_int32_tensor_ops):
                valid = False
                extra.append(tens.name)
        extra = "Op '{}' has int32 tensor(s): {}".format(op.name, ", ".join(extra))
        return valid, extra

    @classmethod
    @docstring_format_args(tens_dim_range)
    def constraint_tens_dimension(cls, op):
        "Tensor dimensions must be in the range {}-{} (inclusive)"
        tens_min, tens_max = cls.tens_dim_range
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        tensors = tensors if tensors else op.inputs
        for tens in tensors:
            if not all(tens_min <= dim <= tens_max for dim in tens.shape):
                valid = False
                extra.append("Tensor '{}' has shape: {}".format(tens.name, tens.shape))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_none_check(op):
        "Tensors must have quantization parameters"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if tens.quantization is None:
                valid = False
                extra.append("Tensor '{}' has no quantization parameters".format(tens.name))
        return valid, ", ".join(extra)

    @staticmethod
    def constraint_tens_quant_scale(op):
        "Tensors with quantization scales must be finite"
        valid = True
        extra = []
        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
        for tens in tensors:
            if (tens.quantization.scale_f32 is not None) and np.isinf(tens.quantization.scale_f32).any():
                valid = False
                extra.append("Tensor '{}' has quantization scale: {}".format(tens.name, tens.quantization.scale_f32))
        return valid, ", ".join(extra)

    @classmethod
    @docstring_format_args([supported_fused_activations])
    def constraint_faf(cls, op):
        "The fused activation function (if present) must be one of type: {}"
        faf = op.activation
        valid = (faf is None) or (faf in cls.supported_fused_activations)
        extra = "Op '{}' has its fused activation function as: {}".format(op.name, faf)
        return valid, extra

    @classmethod
    def check_convolution_restrictions(cls, op):
        # check stride
        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
        if not is_integer(stride_w) or not is_integer(stride_h):
            print("Warning:", op.type, "has non-integer stride, placing on CPU")
            return False
        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
            print(
                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
                    op.type, stride_w, stride_h
                )
            )
            return False

        # check dilation
        dilation_w_factor = op.attrs.get("dilation_w_factor", 1)
        dilation_h_factor = op.attrs.get("dilation_h_factor", 1)
        if not is_integer(dilation_w_factor) or not is_integer(dilation_h_factor):
            print("Warning:", op.type, "has non-integer dilation factor, placing on CPU")
            return False
        if not 1 <= dilation_w_factor <= 2 or not 1 <= dilation_h_factor <= 2:
            print(
                "Warning:",
                op.type,
                "has dilation factors ({}, {}), only factors in range [1, 2] are allowed. Placing on CPU".format(
                    dilation_w_factor, dilation_h_factor
                ),
            )
            return False

        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            print("Warning: only 8-bit weights are supported, placing on CPU")
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check kernel size [HWIO]
        dilated_weight_w = (weight_tensor.shape[1] - 1) * dilation_w_factor + 1
        dilated_weight_h = (weight_tensor.shape[0] - 1) * dilation_h_factor + 1
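        # e.g. a 3x3 kernel with dilation factor 2 covers (3 - 1) * 2 + 1 = 5 input elements per axis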

        # kernel limits
        if not 1 <= dilated_weight_h <= 64:
            print("Warning:", op.type, "has kernel height outside of range [1, 64], placing on CPU")
            return False
        if not 1 <= dilated_weight_w * dilated_weight_h <= 64 * 64:
            print(
                "Warning: product of kernel width and height must be >= 1 and not exceed 64 * 64 ({}),".format(64 * 64),
                "placing on CPU",
            )
            return False

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-constant weights, placing on CPU")
            return False

        # check weight sums over [HWI]
        zero_point = weight_tensor.quantization.zero_point
        quant_weights = weight_tensor.quant_values.astype(np.int64)
        weights = quant_weights - zero_point
        totals = np.sum(np.absolute(weights), axis=(0, 1, 2))
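        # totals now holds, per output channel, the sum of absolute zero-point-adjusted weights over [HWI]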

        if np.amax(totals) > 127 * 65536:
            print("Warning: sum of weights exceeds 127 * 65536 ({}), placing on CPU".format(127 * 65536))
            return False

        # check batch size
        if ifm_tensor.shape[0] != 1:
            print("Warning: only batch sizes of 1 are supported, placing on CPU")
            return False

        return True

    @classmethod
    def check_depthwise_convolution_restrictions(cls, op):
        # check depth
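        # (e.g. depth_multiplier 4 is only accepted when the IFM has 1 channel and the OFM has 4 channels)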
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if op.attrs["depth_multiplier"] > 1 and not (
            (ifm_tensor.shape[3] == 1) and (ofm_tensor.shape[3] == op.attrs["depth_multiplier"])
        ):
            print(
                "Warning: for depth multipliers > 1,",
                "number of input channels must be 1 and number of output channels must be equal to depth multiplier.",
                "Placing on CPU",
            )
            return False
        return cls.check_convolution_restrictions(op)

    @classmethod
    def check_transpose_convolution_restrictions(cls, op):
        # check stride
        stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
        if stride_h != 2 or stride_w != 2:
            print("Warning: stride must be equal to 2, placing on CPU")
            return False

        # check output dimensions
        ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
        ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
        ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
        if op.attrs["padding"] == b"SAME":
            if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
                print(
                    "Warning: for",
                    op.type,
                    "using SAME padding, output dimensions must equal input dimensions multiplied by stride.",
                    "Placing on CPU",
                )
                return False
        elif op.attrs["padding"] == b"VALID":
            kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
            if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or (
                ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0)
            ):
                print(
                    "Warning: for",
                    op.type,
                    "using VALID padding, output dimensions must equal input dimensions multiplied by stride,",
                    "minus difference between kernel size and stride. Placing on CPU",
                )
                return False

        return cls.check_convolution_restrictions(op)

    @classmethod
    def check_pooling_restrictions(cls, op):
        # check stride
        stride_w, stride_h = op.attrs["stride_w"], op.attrs["stride_h"]
        if not is_integer(stride_w) or not is_integer(stride_h):
            print("Warning:", op.type, "has non-integer stride, placing on CPU")
            return False
        if not 1 <= stride_w <= 3 or not 1 <= stride_h <= 3:
            print(
                "Warning: {} has stride ({}, {}), only strides in range [1, 3] are allowed. Placing on CPU".format(
                    op.type, stride_w, stride_h
                )
            )
            return False

        # check data type
        ifm_tensor, ofm_tensor = op.get_ifm_ofm()
        if ifm_tensor.dtype != ofm_tensor.dtype:
            if op.type != Op.ReduceSum:
                print("Warning: input data type doesn't match output data type, placing on CPU")
                return False
            # TODO: else check ReduceSum restrictions.

        # check batch size
        if ifm_tensor.shape[0] != 1:
            print("Warning: input batch size must be 1, placing on CPU")
            return False

        # check kernel size
        kernel_w, kernel_h = op.attrs["filter_width"], op.attrs["filter_height"]
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"SAME":
            if not 1 <= kernel_w <= 8 or not 1 <= kernel_h <= 8:
                print(
                    "Warning:",
                    op.type,
                    "has kernel size ({}, {}), only kernel sizes in range [1, 8] are allowed. Placing on CPU".format(
                        kernel_w, kernel_h
                    ),
                )
                return False
        if op.type in cls.avg_pooling_ops and op.attrs["padding"] == b"VALID" or op.type in cls.max_pooling_ops:
            if not 1 <= kernel_w * kernel_h <= 256 * 256:
                print(
                    "Warning: product of kernel width and height must be >= 1 and not exceed 256 * 256 ({}),".format(
                        256 * 256
                    ),
                    "placing on CPU",
                )
                return False
            if not 1 <= kernel_h <= 256:
                print("Warning:", op.type, "has kernel height outside of range [1, 256], placing on CPU")
                return False

        return True

    @classmethod
    def check_resize_restrictions(cls, op):
        # check unsupported upscaling factor
        if op.type == Op.ResizeBilinear:
            if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
                return True
            if op.inputs[0].shape == op.outputs[0].shape:
                return True
            upscaled_shape = np.array(op.inputs[0].shape[1:3])
            out_shape = np.array(op.outputs[0].shape[1:3])
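            # e.g. without align_corners a 4x4 IFM can only upscale to 8x8, 16x16, ... (doubling per pass);
            # with align_corners the reachable sizes are 7x7, 13x13, ... ((2 * n) - 1 per pass)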
            while (upscaled_shape < out_shape).all():
                upscaled_shape *= 2
                if op.attrs["align_corners"]:
                    upscaled_shape -= 1
                if np.array_equal(out_shape, upscaled_shape):
                    return True
            return False

    @classmethod
    def check_vector_product_restrictions(cls, op):
        # check data type
        ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
        if weight_tensor.element_size() > 1:
            print("Warning: only 8-bit datatypes supported for {}, placing on CPU".format(op.type))
            return False

        if not cls.check_bias_restrictions(bias_tensor):
            return False

        # check non const weights
        if weight_tensor.values is None:
            print("Warning:", op.type, "has non-const weights, placing on CPU")
            return False

        return True

    @classmethod
    def check_element_wise_restrictions(cls, op):
        # check data type
        ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
        # input and output datatype must match for these operators
        if (
            op.type in cls.binary_elem_wise_min_max_ops | cls.unary_elem_wise_main_ops
            and ifm_tensor.dtype != ofm_tensor.dtype
        ):
            print("Warning:", op.type, "must have same input and output datatype, placing on CPU")
            return False
        if op.type in cls.binary_elem_wise_add_mul_sub:
            # both inputs must have same type
            if ifm_tensor.dtype != ifm2_tensor.dtype:
                print("Warning:", op.type, "must have same datatype on both inputs, placing on CPU")
                return False
            # signed input check
            if ifm_tensor.dtype.type & BaseType.Signed:
                # output must be signed
                if ofm_tensor.dtype.type & BaseType.Unsigned:
                    print("Warning: only signed output types supported for {}, placing on CPU".format(op.type))
                    return False
                # and 8, 16 or 32-bit
                bit_lengths = {8, 16, 32}
                if ofm_tensor.element_size() * 8 not in bit_lengths:
                    print(
                        "Warning:", op.type, "is only supported for bit lengths {}, placing on CPU".format(bit_lengths)
                    )
                    return False
            # unsigned input check, output must be same type or int32
            if ifm_tensor.dtype.type & BaseType.Unsigned and not (
                ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32
            ):
                print("Warning:", op.type, "has unsigned input but output is not unsigned or int32, placing on CPU")
                return False
        elif op.type in cls.binary_elem_wise_shift_ops:
            if ifm_tensor.dtype != DataType.int32 or ifm2_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "input datatypes are not int32, placing on CPU")
                return False
            if op.type in (Op.CLZ, Op.SHL) and ofm_tensor.dtype != DataType.int32:
                print("Warning:", op.type, "output datatype is not int32, placing on CPU")
                return False

        # check batch size
        if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
            print(
                "Warning:",
                op.type,
                "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
            )
            return False
        if op.type in cls.binary_elem_wise_main_ops:  # if op type is unary, ifm2_tensor is None
            if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1:
                print(
                    "Warning:",
                    op.type,
                    "only supports batch size 1 for tensors with more than 2 dimensions, placing on CPU",
                )
                return False

        # negative alpha values are not supported
        if op.type == Op.LeakyRelu and op.attrs["alpha"] < 0:
            print("Warning:", op.type, "has negative alpha, placing on CPU")
            return False

        # check if ifm or ifm2 has ofm shape
        if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape:
            print("Warning:", op.type, "input shape(s) differ from output shape, placing on CPU")
            return False

        if op.type in cls.binary_elem_wise_min_max_ops and not cls.check_quantization_restrictions_binary_elem_wise(op):
            return False

        return True

    @classmethod
    def check_memory_only_restrictions(cls, op):
        if op.type == Op.StridedSlice:
            if len(op.inputs) != 4:
                warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
                return False
            input_tens, begin_tens, end_tens, strides_tens = op.inputs
            if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
                warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
                return False
            if not (
                len(input_tens.shape)
                == len(op.outputs[0].shape)
                == len(begin_tens.values)
                == len(end_tens.values)
                == len(strides_tens.values)
            ):
                warn_cpu(op, "has input tensors with shapes that are not supported")
                return False
            # check stride size
            if any(stride != 1 for stride in strides_tens.values):
                warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
                return False
            # check ellipsis_mask
            if op.attrs["ellipsis_mask"] != 0:
                warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
                return False
            # check if both new_axis_mask and shrink_axis_mask have bit set
            if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
                warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
                return False
            # Calculate offset start/end
            offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
            offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
            # check "end - begin" doesn't result in any zero or negative elements
            if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
                warn_cpu(
                    op,
                    "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
                        begin_tens.values, end_tens.values
                    ),
                )
                return False
        if op.type == Op.SplitV:
            # check that maximum one size is set to -1, indicating that size should be inferred
            sizes = op.inputs[1].values
            num_to_be_inferred = 0
            for size in sizes:
                if size == -1:
                    num_to_be_inferred += 1

            if num_to_be_inferred > 1:
                print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU")
                return False
        if op.type in set((Op.Concat, Op.ConcatTFLite,)):
            axis = op.attrs.get("axis", None)
            if axis is None:
                print("Warning:", op.type, "invalid or missing axis, placing on CPU")
                return False
            if axis < 0:
                axis += len(op.inputs[0].shape)
            if not 0 <= axis < len(op.inputs[0].shape):
                print("Warning:", op.type, "invalid axis", axis, ", placing on CPU")
                return False
            ofm = op.outputs[0]
            ofm_dims = len(ofm.shape)
            for ifm in op.inputs:
                if len(ifm.shape) != ofm_dims:
                    return False
                for i in range(ofm_dims):
                    if i != axis and ifm.shape[i] != ofm.shape[i]:
                        print(
                            "Warning:",
                            op.type,
                            "invalid ifm:",
                            ifm.name,
                            ifm.shape,
                            "mismatch in dimension",
                            i,
                            ", placing on CPU",
                        )
                        return False

        return True

    @classmethod
    def check_quantization_restrictions_binary_elem_wise(cls, op):
        # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
        assert len(op.inputs) >= 2 and len(op.outputs) == 1

        if (
            op.inputs[0].quantization is None
            or not op.inputs[0].is_scaling_equal(op.inputs[1])
            or not op.inputs[0].is_scaling_equal(op.outputs[0])
        ):
            print(
                "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator"
            )
            return False

        return True

    @classmethod
    def check_activation_ops(cls, op):
        if op.type == Op.Softmax:
            ifm_tensor = op.inputs[0]
            ofm_tensor = op.outputs[0]

            # check data type
            if ifm_tensor.dtype != ofm_tensor.dtype:
                print("Warning:", op.type, "input type differs from output type, placing on CPU")
                return False

            if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
                print(
                    "Warning: only datatypes supported for {} are uint8, int8 and int16; placing on CPU".format(op.type)
                )
                return False

            # check shape
            if ifm_tensor.shape != ofm_tensor.shape:
                print("Warning:", op.type, "input shape differs from output shape, placing on CPU")
                return False

        return True

    @classmethod
    def check_bias_restrictions(cls, bias_tensor):
        # check data type
        if bias_tensor is not None and bias_tensor.dtype not in (DataType.int32, DataType.int64):
            print("Warning: bias tensor datatype must be int32 or int64, placing on CPU")
            return False

        # check if values fit in 40 bits
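        # (signed 40-bit range: -(1 << 39) .. (1 << 39) - 1, i.e. -549755813888 .. 549755813887)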
        if bias_tensor is not None and bias_tensor.dtype == DataType.int64:
            for quant_value in bias_tensor.quant_values:
                if not (-(1 << 39) <= quant_value < (1 << 39)):
                    print("Warning: bias tensor values are larger than 40 bits, placing on CPU")
                    return False

        return True